## Multi-head attention transformer
### Encoder and Decoder
### (With masking)


### Initialising an encoder-decoder Transformer from PyTorch modules (such as nn.Transformer, nn.Linear, etc.)

In [106]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import numpy as np



In [107]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional-encoding table.

    Args:
    - d_model: Embedding width. Must be even: sine fills the even columns and
      cosine the odd ones.
    - max_len: Number of positions that are precomputed.
    - dropout: Accepted for API compatibility; it is not used.

    The table is stored as a buffer named ``encoding`` with shape
    (1, max_len, d_model). ``forward(x)`` returns only the encoding, not
    ``x + encoding``; the caller adds it to the embeddings.
    """

    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()

        # Positions [0, 1, ..., max_len-1] as a column vector.
        position = torch.arange(0, max_len).unsqueeze(1).float()

        # Geometric sequence of frequencies, one per (sin, cos) column pair.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)  # even columns
        encoding[:, 1::2] = torch.cos(position * div_term)  # odd columns

        # Register the table as a buffer so .to(device) / .cuda() move it with
        # the module. persistent=False keeps it out of state_dict(), so the
        # module's state_dict() keys are the same as before.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table, shape (1, x.size(1), d_model)."""
        # NOTE(review): this treats dim 1 of x as the sequence axis. For a
        # (seq_len, batch) input, dim 1 is the batch size, so only that many
        # leading positions are returned and then broadcast over the sequence.
        return self.encoding[:, :x.size(1)].detach()


In [108]:
class TransformerModel_Enc_Dec(nn.Module):
    """
    Encoder-decoder model built on ``nn.Transformer``.

    Token embeddings plus positional encodings are fed to ``nn.Transformer``.
    A final linear layer maps each decoder output to target-vocabulary logits.
    ``nn.Transformer`` is created without ``batch_first``, so ``src``/``tgt``
    are (seq_len, batch) index tensors. The output has shape
    (tgt_len, batch, tgt_vocab_size).

    Args:
    - src_vocab_size, tgt_vocab_size: Vocabulary sizes of the two embeddings.
    - d_model: Embedding / model width.
    - num_heads: Attention heads per layer.
    - num_encoder_layers, num_decoder_layers: Depth of each stack.
    - max_seq_len: Number of precomputed positional encodings.
    - d_ff: Hidden width of the feed-forward sub-layers.
    - dropout: Dropout probability inside ``nn.Transformer``.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super(TransformerModel_Enc_Dec, self).__init__()

        # NOTE: sub-modules are created in this exact order because random
        # initialisation consumes the RNG in order (torch.manual_seed is used
        # for reproducibility).
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # Positional encoding (parameter-free).
        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout, max_len=max_seq_len)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        # Projects decoder outputs to target-vocabulary logits.
        self.fc = nn.Linear(d_model, tgt_vocab_size)

    def generate_mask(self, src, tgt):
        """
        Build the attention masks.

        Returns:
        - src_mask: Always None (the source is not masked).
        - nopeak_mask: Boolean (tgt_len, tgt_len) tensor on ``tgt``'s device.
          True above the diagonal marks future positions that may not be
          attended to.
        """
        src_mask = None
        seq_length = tgt.size(0)

        # Build the mask on the same device as the inputs so that GPU inputs
        # do not get a CPU mask.
        nopeak_mask = torch.triu(torch.ones(seq_length, seq_length, device=tgt.device), diagonal=1).bool()

        return src_mask, nopeak_mask

    def forward(self, src, tgt):
        """Return logits of shape (tgt_len, batch, tgt_vocab_size)."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)

        # NOTE(review): PositionalEncoding.forward slices by x.size(1). With
        # these seq-first inputs that is the batch size, so for batch=1 only
        # the position-0 encoding is added to every token. The NumPy
        # re-implementation below mirrors this, so it is left unchanged here.
        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)

        output = self.transformer(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask)

        return self.fc(output)
    

In [109]:
import numpy as np

# Vocabulary sizes of the toy source and target languages.
src_vocab_size, tgt_vocab_size = 20, 20

# Model hyper-parameters.
d_model, num_heads = 16, 4
num_encoder_layers = num_decoder_layers = 1
d_ff, max_seq_len, dropout = 20, 5, 0

# Token-id column vectors: shapes (4, 1) and (5, 1).
src_data = torch.tensor([[2], [1], [5], [4]])
tgt_data = torch.tensor([[1], [16], [5], [3], [9]])


In [110]:
# Width of each attention head: the model width split evenly across heads.
head_dim = d_model//num_heads

In [111]:
# Seed the RNG so the randomly initialised weights (and the outputs below) are reproducible.
torch.manual_seed(0)
transformer = TransformerModel_Enc_Dec(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    max_seq_len=max_seq_len,
    d_ff=d_ff,
)

In [112]:
# Decoder input: every target token except the last one.
tgt_data1 = tgt_data[:-1]
# Call the module rather than .forward() directly (the idiomatic form, which also runs any hooks).
output = transformer(src=src_data, tgt=tgt_data1)

In [113]:
# Show the target-vocabulary logits produced by the torch model, and their shape.
print(output, output.shape)

tensor([[[ 0.3775,  0.0734, -0.8096,  0.1469, -0.4905,  0.3739, -0.2705,
          -0.9242,  0.6274, -0.6122,  0.2593, -0.1688,  0.1323,  1.3099,
           0.7643, -0.6106, -0.1371,  0.7552, -0.5415,  0.7423]],

        [[ 0.4344,  0.3561, -0.4982, -0.4326, -0.2377,  0.8758, -1.0641,
          -0.3423,  0.3772, -0.4452,  0.6519, -0.2666, -0.0646,  0.7807,
           0.5272, -0.6573,  0.0976,  1.0804, -0.9153,  1.1743]],

        [[-0.5873,  0.2161, -0.6100, -0.0774, -0.5958,  0.6877, -0.2677,
          -0.4463,  0.5172,  1.1137,  0.1040, -0.3007, -1.1872, -0.2548,
           0.6388,  0.0914,  0.0290, -0.0519,  0.3834, -0.1068]],

        [[ 0.7888,  0.7641, -0.1245, -0.0249,  0.1865, -0.0143, -0.9726,
          -0.9177,  0.3889, -0.3268,  0.1702, -0.8667,  0.2136,  0.3825,
           0.2635, -0.9925, -0.4424,  1.0395, -1.2948,  0.7781]]],
       grad_fn=<ViewBackward0>) torch.Size([4, 1, 20])


In [114]:
# Snapshot of the initialised parameters; the manual NumPy computation below reads every weight from this dict.
state_dict = transformer.state_dict()

## **Manual computation of each layer of the transformer**

In [115]:
# Assume these are the individual token indices obtained by tokenizing the input and output,
# stored as NumPy column vectors (same values as the torch tensors above).
src_data = np.array([2, 1, 5, 4]).reshape(-1, 1)
tgt_data = np.array([1, 16, 5, 3, 9]).reshape(-1, 1)


### 1. Source token embeddings 

In [116]:
# Source-side embedding table taken from the model's parameters.
src_vocab_embeds = state_dict["src_embedding.weight"]

# One d_model-wide row per source token: (source sequence length, d_model).
src_embedding = np.zeros((src_data.shape[0], d_model))

for i, word_index in enumerate(src_data):
    # Reject token ids that fall outside the vocabulary.
    if word_index < 0 or word_index >= src_vocab_embeds.shape[0]:
        raise ValueError(f"Invalid word index: {word_index}")

    # Copy this token's embedding row into the output matrix.
    src_embedding[i, :] = src_vocab_embeds[word_index, :]

    print(f"Word index: {word_index}, Embedding: {src_vocab_embeds[word_index, :]}")

print()
print(src_embedding.shape)


Word index: [2], Embedding: tensor([[-0.6136,  0.0316, -0.4927,  0.2484,  0.4397,  0.1124,  0.6408,  0.4412,
         -0.1023,  0.7924, -0.2897,  0.0525,  0.5229,  2.3022, -1.4689, -1.5867]])
Word index: [1], Embedding: tensor([[-1.3527, -1.6959,  0.5667,  0.7935,  0.5988, -1.5551, -0.3414,  1.8530,
          0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463, -0.8437]])
Word index: [5], Embedding: tensor([[-9.3348e-02,  6.8705e-01, -8.3832e-01,  8.9182e-04,  8.4189e-01,
         -4.0003e-01,  1.0395e+00,  3.5815e-01, -2.4600e-01,  2.3025e+00,
         -1.8817e+00, -4.9727e-02, -1.0450e+00, -9.5650e-01,  3.3532e-02,
          7.1009e-01]])
Word index: [4], Embedding: tensor([[-0.5692,  0.9200,  1.1108,  1.2899, -1.4782,  2.5672, -0.4731,  0.3356,
         -1.6293, -0.5497, -0.4798, -0.4997, -1.0670,  1.1149, -0.1407,  0.8058]])

(4, 16)


### 2. Source token embeddings + positional embeddings 

Re-implementing the positional-encoding class in NumPy

In [117]:
import numpy as np

class PositionalEncoding_np:
    """
    NumPy port of the sinusoidal positional-encoding table.

    ``self.encoding`` holds the table with shape (1, max_len, d_model). Even
    columns hold sines and odd columns hold cosines. ``dropout`` is accepted
    for signature parity and ignored.
    """

    def __init__(self, d_model, max_len=512, dropout=0):
        pos = np.arange(0, max_len).reshape(-1, 1).astype(np.float32)
        freqs = np.exp(np.arange(0, d_model, 2).astype(np.float32) * -(np.log(10000.0) / d_model))
        angles = pos * freqs

        table = np.zeros((max_len, d_model))
        table[:, 0::2] = np.sin(angles)  # even columns
        table[:, 1::2] = np.cos(angles)  # odd columns
        self.encoding = table[np.newaxis, :]

    def forward(self, x):
        """
        Return the encoding of position 0 only, shape (d_model,). ``x`` is ignored.

        This mirrors the torch model here: its PositionalEncoding slices
        ``x.size(1)`` positions, which is 1 for the (seq_len, 1) inputs used,
        so only the position-0 row is added to every token.
        """
        first_block = self.encoding[0]
        return first_block[0]

In [118]:
# Model width and maximum sequence length, matching the torch model above.
d_model, max_seq_len = 16, 5

pe = PositionalEncoding_np(max_len=max_seq_len, d_model=d_model)

# Add the positional encoding to every source-token embedding (broadcast over rows).
pe_src_embeds = src_embedding + pe.forward(src_data)

pe_src_embeds.shape, pe_src_embeds

((4, 16),
 array([[-0.61358309,  1.03159274, -0.49267703,  1.24841475,  0.43969586,
          1.11241119,  0.64079237,  1.44115627, -0.10230965,  1.79244399,
         -0.28966758,  1.05250749,  0.52286041,  3.30220532, -1.46889389,
         -0.58668876],
        [-1.35265374, -0.69593132,  0.56665051,  1.79350841,  0.59883946,
         -0.55509508, -0.3413603 ,  2.85300612,  0.75018942,  0.41450286,
         -0.17339702,  1.18347792,  1.38936615,  2.58633435,  0.94629836,
          0.15632319],
        [-0.09334823,  1.68705022, -0.83831537,  1.00089182,  0.84189409,
          0.59996545,  1.03946197,  1.3581531 , -0.24600095,  3.30251646,
         -1.88168919,  0.95027298, -1.04497862,  0.04349947,  0.03353186,
          1.71008658],
        [-0.56924802,  1.91997129,  1.11081612,  2.28987384, -1.47817433,
          3.56723285, -0.4731198 ,  1.33555073, -1.62932599,  0.45025635,
         -0.47983426,  0.50031784, -1.06698   ,  2.11493957, -0.14067143,
          1.80575365]]))

In [119]:
# Encoder input: source embeddings plus positional encoding.
x_enc = pe_src_embeds

### 3. Q, K, V matrices from the model's initialised weights

We define a function that computes these matrices for any layer of the encoder or decoder.

In [120]:
import numpy as np

def get_QKV_matrices(x, layer_num, state_dict, num_heads, is_decoder=False):
    """
    Project ``x`` into per-head Q, K and V using a layer's self-attention weights.

    Args:
    - x: 2-D input array of shape (seq_len, embed_dim).
    - layer_num: Index of the encoder/decoder layer whose weights are used.
    - state_dict: Model state dict. Reads
      ``transformer.{encoder|decoder}.layers.{layer_num}.self_attn.in_proj_weight``
      and ``...in_proj_bias``.
    - num_heads: Number of attention heads; must divide embed_dim.
    - is_decoder: Use decoder-layer weights instead of encoder-layer weights.

    Returns:
    - Q, K, V: Arrays of shape (num_heads, seq_len, embed_dim // num_heads).

    Raises:
    - ValueError: If embed_dim is not divisible by num_heads.
    """
    seq_len, embed_dim = x.shape
    if embed_dim % num_heads != 0:
        raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
    head_dim = embed_dim // num_heads

    layer_type = "decoder" if is_decoder else "encoder"
    prefix = f"transformer.{layer_type}.layers.{layer_num}.self_attn"
    W = state_dict[f"{prefix}.in_proj_weight"].numpy()
    b = state_dict[f"{prefix}.in_proj_bias"].numpy()

    # Packed projection: rows [0:E) of W give Q, [E:2E) give K, [2E:3E) give V.
    # The in_proj bias is part of the projection, as in nn.MultiheadAttention.
    qkv = np.matmul(x, W.T) + b
    Q = qkv[:, 0:embed_dim]
    K = qkv[:, embed_dim:2 * embed_dim]
    V = qkv[:, 2 * embed_dim:3 * embed_dim]

    print("Q_shape = ", Q.shape)
    print("K_shape = ", K.shape)
    print("V_shape = ", V.shape)
    print()

    print("After reshaping... \n")

    def _split_heads(m):
        # (seq_len, embed_dim) -> (num_heads, seq_len, head_dim)
        return np.transpose(np.reshape(m, (m.shape[0], num_heads, head_dim)), (1, 0, 2))

    Q, K, V = _split_heads(Q), _split_heads(K), _split_heads(V)

    print("Q_shape = ", Q.shape)
    print("K_shape = ", K.shape)
    print("V_shape = ", V.shape)

    return Q, K, V

In [121]:
# Sequence length and width of the encoder input.
tgt_len, embed_dim = x_enc.shape

Q_enc, K_enc, V_enc = get_QKV_matrices(x_enc, 0, state_dict, num_heads, is_decoder=False)

Q_shape =  (4, 16)
K_shape =  (4, 16)
V_shape =  (4, 16)

After reshaping... 

Q_shape =  (4, 4, 4)
K_shape =  (4, 4, 4)
V_shape =  (4, 4, 4)


### 4. Self-Attention in the Encoder

In [122]:
import numpy as np
import math

def calculate_attention(Q, K, V, attn_mask = None):
    """
    Scaled dot-product attention over a stack of heads.

    Args:
    - Q: Query array of shape (num_heads, q_len, head_dim).
    - K: Key array of shape (num_heads, k_len, head_dim).
    - V: Value array of shape (num_heads, k_len, v_dim).
    - attn_mask: Optional mask broadcastable to (num_heads, q_len, k_len).
      A float mask is added to the scores (use -inf to block a position).
      A boolean mask follows the PyTorch convention: True means "may not
      attend", and it is converted to -inf / 0.

    Returns:
    - attn_output: Array of shape (num_heads, q_len, v_dim).
    """
    scale_factor = 1 / math.sqrt(Q.shape[-1])

    # (heads, q_len, d) @ (heads, d, k_len) -> (heads, q_len, k_len)
    scores = (Q @ np.transpose(K, axes=(0, 2, 1))) * scale_factor

    if attn_mask is not None:
        attn_mask = np.asarray(attn_mask)
        if attn_mask.dtype == bool:
            attn_mask = np.where(attn_mask, -np.inf, 0.0)
        scores = scores + attn_mask

    # Softmax is shift-invariant. Subtracting the row max keeps np.exp from
    # overflowing to inf, which would otherwise turn the weights into nan.
    scores = scores - np.max(scores, axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / np.sum(weights, axis=-1, keepdims=True)

    return weights @ V

In [123]:
# Encoder self-attention for every head; the source side is not masked.
attn_output_enc = calculate_attention(Q=Q_enc, K=K_enc, V=V_enc, attn_mask=None)

In [124]:
# Per-head attention outputs, shape (num_heads, seq_len, head_dim).
attn_output_enc, attn_output_enc.shape

(array([[[ 0.71996273,  0.37472799,  0.19898526, -0.09682794],
         [ 0.74778944,  0.41424775,  0.14055336, -0.09750267],
         [ 0.6835169 ,  0.43522334,  0.44485319, -0.15091043],
         [ 0.55486831, -0.18135746, -0.03144971,  0.15490705]],
 
        [[ 0.72312542,  0.11965131, -0.6167594 ,  0.41149312],
         [ 0.8335319 ,  0.10672248, -0.50738735,  0.41318483],
         [ 0.68540564,  0.1174512 , -0.58773575,  0.70770555],
         [ 0.65851092,  0.1460082 , -0.62973259,  0.47600019]],
 
        [[-0.17895638, -0.97088941,  0.21564848,  0.34026238],
         [-0.15921383, -0.95981101,  0.22985974,  0.33606614],
         [ 0.17079554, -0.57683205,  0.32521921,  0.41776349],
         [-0.21387928, -0.85617058,  0.13442093,  0.39797621]],
 
        [[ 0.45370406,  0.97827887,  0.88422871, -0.11814313],
         [ 0.39785863,  0.99090849,  0.84784164, -0.16750662],
         [ 0.97294568,  0.8204363 ,  1.46646822, -0.43123428],
         [ 0.54188754,  0.98235485,  0.9537199

In [125]:
def reshape_attention_op(attn_output, num_heads, tgt_len, head_dim):
    """
    Concatenate per-head attention outputs back into one matrix.

    Row ``t`` of the result is ``[head_0[t], head_1[t], ..., head_{H-1}[t]]``.
    This is the layout ``out_proj`` expects.

    Args:
    - attn_output: Array of shape (num_heads, tgt_len, head_dim).
    - num_heads: Number of attention heads.
    - tgt_len: Sequence length.
    - head_dim: Width of each head.

    Returns:
    - final_attn_sa_op: float64 array of shape (tgt_len, num_heads * head_dim).

    Raises:
    - ValueError: If attn_output does not have shape (num_heads, tgt_len, head_dim).
    """
    attn_output = np.asarray(attn_output)
    expected_shape = (num_heads, tgt_len, head_dim)
    if attn_output.shape != expected_shape:
        raise ValueError(f"expected attention output of shape {expected_shape}, got {attn_output.shape}")

    # (heads, seq, head_dim) -> (seq, heads, head_dim) -> (seq, heads * head_dim).
    # Replaces an index loop that only worked when num_heads == tgt_len == head_dim.
    return np.transpose(attn_output, (1, 0, 2)).reshape(tgt_len, num_heads * head_dim).astype(np.float64)


In [126]:
attn_sa_enc = reshape_attention_op(attn_output=attn_output_enc, num_heads=num_heads, tgt_len=tgt_len, head_dim=head_dim)

In [127]:
# Heads concatenated per position: shape (seq_len, num_heads * head_dim).
attn_sa_enc, attn_sa_enc.shape

(array([[ 0.71996273,  0.37472799,  0.19898526, -0.09682794,  0.72312542,
          0.11965131, -0.6167594 ,  0.41149312, -0.17895638, -0.97088941,
          0.21564848,  0.34026238,  0.45370406,  0.97827887,  0.88422871,
         -0.11814313],
        [ 0.74778944,  0.41424775,  0.14055336, -0.09750267,  0.8335319 ,
          0.10672248, -0.50738735,  0.41318483, -0.15921383, -0.95981101,
          0.22985974,  0.33606614,  0.39785863,  0.99090849,  0.84784164,
         -0.16750662],
        [ 0.6835169 ,  0.43522334,  0.44485319, -0.15091043,  0.68540564,
          0.1174512 , -0.58773575,  0.70770555,  0.17079554, -0.57683205,
          0.32521921,  0.41776349,  0.97294568,  0.8204363 ,  1.46646822,
         -0.43123428],
        [ 0.55486831, -0.18135746, -0.03144971,  0.15490705,  0.65851092,
          0.1460082 , -0.62973259,  0.47600019, -0.21387928, -0.85617058,
          0.13442093,  0.39797621,  0.54188754,  0.98235485,  0.95371994,
         -0.1078992 ]]),
 (4, 16))

### 5. Post self attention in the encoder block

In [128]:
def _layer_norm_affine(x, weight, bias, eps=1e-5):
    """Normalise ``x`` over its last axis, then apply ``weight``/``bias`` (torch.nn.LayerNorm formula)."""
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)  # biased variance, as LayerNorm uses
    return (x - mean) / np.sqrt(var + eps) * weight + bias


def linear_layer_forward(x_enc, final_attn_sa_op, state_dict, layer_num, is_decoder = False):
    """
    Post-attention part of a (post-norm) encoder layer.

    Steps: out_proj -> residual -> norm1 -> linear1/ReLU/linear2 -> residual -> norm2.

    Args:
    - x_enc: Layer input, shape (seq_len, d_model).
    - final_attn_sa_op: Concatenated per-head attention output, shape (seq_len, d_model).
    - state_dict: Model state dict (torch tensors).
    - layer_num: Index of the layer.
    - is_decoder: Read ``transformer.decoder...`` keys instead of ``transformer.encoder...``.
      NOTE(review): a decoder layer also has cross-attention and a norm3, which
      this function does not apply.

    Returns:
    - output_enc_final: Layer output, shape (seq_len, d_model).
    """
    layer_type = "decoder" if is_decoder else "encoder"
    prefix = f"transformer.{layer_type}.layers.{layer_num}"

    def param(name):
        return state_dict[f"{prefix}.{name}"].numpy()

    # Self-attention output projection, residual connection, then norm1.
    attn_proj = np.matmul(final_attn_sa_op, param("self_attn.out_proj.weight").T) + param("self_attn.out_proj.bias")
    hidden = _layer_norm_affine(attn_proj + x_enc, param("norm1.weight"), param("norm1.bias"))

    # Position-wise feed-forward network.
    ff = np.maximum(np.matmul(hidden, param("linear1.weight").T) + param("linear1.bias"), 0)
    ff = np.matmul(ff, param("linear2.weight").T) + param("linear2.bias")

    # Residual connection, then norm2 (with norm2's own bias).
    output_enc_final = _layer_norm_affine(ff + hidden, param("norm2.weight"), param("norm2.bias"))

    return output_enc_final


In [129]:
import numpy as np

def attention_output_projection(final_attn_sa_op, state_dict, layer_num, is_decoder = False, is_crossattn = False):
    """
    Apply an attention block's output projection (``out_proj``).

    Args:
    - final_attn_sa_op: Concatenated per-head attention output.
    - state_dict: Model state dict (torch tensors).
    - layer_num: Index of the layer.
    - is_decoder: Use decoder-layer weights instead of encoder-layer weights.
    - is_crossattn: Use the ``multihead_attn`` (cross-attention) block instead of ``self_attn``.

    Returns:
    - ``final_attn_sa_op @ W_out.T + b_out``.
    """
    stack = "decoder" if is_decoder else "encoder"
    block = "multihead_attn" if is_crossattn else "self_attn"
    prefix = f"transformer.{stack}.layers.{layer_num}.{block}.out_proj"

    w_out = state_dict[prefix + ".weight"].numpy()
    b_out = state_dict[prefix + ".bias"].numpy()

    return final_attn_sa_op @ w_out.T + b_out

def layer_norm(input_tensor, state_dict, layer_num, suffix, is_decoder=False, eps=1e-5):
    """
    Apply a layer's LayerNorm using its stored weight and bias.

    Computes ``(x - mean) / sqrt(var + eps) * weight + bias`` over the last
    axis, the torch.nn.LayerNorm formula. The affine transform comes after
    normalisation.

    Args:
    - input_tensor: Array to normalise over its last axis.
    - state_dict: Model state dict (torch tensors).
    - layer_num: Index of the layer.
    - suffix: Norm name inside the layer, e.g. "norm1" / "norm2" / "norm3".
    - is_decoder: Use decoder-layer weights instead of encoder-layer weights.
    - eps: Variance epsilon (torch.nn.LayerNorm's default is 1e-5).

    Returns:
    - normalized_tensor: Array with the same shape as input_tensor.
    """
    layer_type = "decoder" if is_decoder else "encoder"

    weight_key = "transformer.{}.layers.{}.{}.weight".format(layer_type, layer_num, suffix)
    bias_key = "transformer.{}.layers.{}.{}.bias".format(layer_type, layer_num, suffix)

    norm_weight = state_dict[weight_key].numpy()
    norm_bias = state_dict[bias_key].numpy()

    mean = np.mean(input_tensor, axis=-1, keepdims=True)
    var = np.var(input_tensor, axis=-1, keepdims=True)  # biased variance, as LayerNorm uses
    normalized = (input_tensor - mean) / np.sqrt(var + eps)

    # Scale and shift AFTER normalising (the previous code did it before).
    normalized_tensor = normalized * norm_weight + norm_bias

    return normalized_tensor

def linear_relu_linear(input_tensor, state_dict, layer_num, is_decoder=False):
    """
    Position-wise feed-forward network: ``linear2(relu(linear1(input_tensor)))``.

    Args:
    - input_tensor: Input array whose last axis has width d_model.
    - state_dict: Model state dict (torch tensors).
    - layer_num: Index of the layer.
    - is_decoder: Use decoder-layer weights instead of encoder-layer weights.

    Returns:
    - Output of the second linear layer.
    """
    stack = "decoder" if is_decoder else "encoder"
    prefix = f"transformer.{stack}.layers.{layer_num}"

    w1 = state_dict[f"{prefix}.linear1.weight"].numpy()
    b1 = state_dict[f"{prefix}.linear1.bias"].numpy()
    w2 = state_dict[f"{prefix}.linear2.weight"].numpy()
    b2 = state_dict[f"{prefix}.linear2.bias"].numpy()

    hidden = np.maximum(input_tensor @ w1.T + b1, 0)  # linear1 + ReLU
    return hidden @ w2.T + b2  # linear2

def encoder_layer_forward(x_enc, final_attn_sa_op, state_dict, layer_num):
    """
    Post-attention part of an encoder layer (post-norm ordering).

    Steps: out_proj -> residual -> norm1 -> feed-forward -> residual -> norm2.

    Args:
    - x_enc: Layer input, shape (seq_len, d_model).
    - final_attn_sa_op: Concatenated per-head self-attention output.
    - state_dict: Model state dict (torch tensors).
    - layer_num: Index of the encoder layer.

    Returns:
    - output_enc_final: Output of the encoder layer.
    """
    attn_proj = attention_output_projection(final_attn_sa_op, state_dict, layer_num)

    # Residual additions are out-of-place: an in-place `+=` would silently cast
    # a float64 residual down to the projection's dtype (e.g. float32).
    normalized_tensor_1 = layer_norm(attn_proj + x_enc, state_dict, layer_num, "norm1")

    ff_out = linear_relu_linear(normalized_tensor_1, state_dict, layer_num)

    output_enc_final = layer_norm(ff_out + normalized_tensor_1, state_dict, layer_num, "norm2")

    return output_enc_final


In [130]:
output_enc_final = encoder_layer_forward(x_enc, attn_sa_enc, state_dict, 0)  # layer 0

In [131]:
# Encoder layer-0 output and its shape.
output_enc_final, output_enc_final.shape

(array([[ 0.07763966, -0.81340241, -0.63372394,  0.23144549, -0.17759687,
          0.3632295 , -1.3014882 ,  1.52573876, -0.57451076,  1.58645178,
          0.40404961,  0.25715019, -1.01007227,  1.90874712, -1.71313098,
         -0.13052668],
        [-0.56655927, -1.69010877, -0.29210758,  1.21460146, -0.11027101,
          0.00789656, -2.11566028,  1.65359012,  0.08810835,  0.56830719,
          0.03861744,  0.5650082 , -0.14632035,  1.70270482, -0.61037473,
         -0.30743213],
        [ 0.11359249, -0.5709097 , -1.03994491, -0.1699108 ,  0.97662986,
         -0.3185164 , -0.41645207,  1.2750162 , -1.37529657,  2.30959525,
         -0.54989653,  0.85788181, -1.11259539, -0.98347755, -0.09901768,
          1.10330199],
        [ 0.18098552,  0.02234932,  0.57643494,  0.78534802, -1.54732829,
          0.83277342, -1.94984843,  0.44517712, -0.74449741,  1.16907805,
          0.10308054, -0.41452632, -1.45966998,  1.31105764, -0.58765698,
          1.27724283]]),
 (4, 16))

In [132]:
# The encoder output becomes the memory used later by decoder cross-attention.
x_enc = output_enc_final

## Decoder block


### Self attention outputs from a decoder block

### 6. Target token embeddings 

In [133]:
# Decoder input: every target token except the last one (next-token prediction, auto-regressive setup).
tgt_data1 = tgt_data[:-1]

In [134]:
# Target-side embedding table taken from the model's parameters.
tgt_vocab_embeds = state_dict["tgt_embedding.weight"]

# One d_model-wide row per decoder-input token.
tgt_embedding = np.zeros((tgt_data1.shape[0], d_model))

for i, word_index in enumerate(tgt_data1):
    # Reject token ids that fall outside the vocabulary.
    if word_index < 0 or word_index >= tgt_vocab_embeds.shape[0]:
        raise ValueError(f"Invalid word index: {word_index}")

    # Copy this token's embedding row into the output matrix.
    tgt_embedding[i, :] = tgt_vocab_embeds[word_index, :]

    print(f"Word index: {word_index}, Embedding: {tgt_vocab_embeds[word_index, :]}")

print()
print(tgt_embedding.shape)


Word index: [1], Embedding: tensor([[ 0.6442,  3.9300, -0.1244,  0.2953,  0.3827, -0.5497, -0.9940,  1.3459,
          1.9457, -1.2904, -2.3495, -2.0689,  0.9094, -0.6946,  1.9595, -1.1038]])
Word index: [16], Embedding: tensor([[-0.8733,  0.0043, -1.2579, -1.0845,  0.7530,  0.3236, -0.2750,  1.3056,
          0.2118,  0.2720, -0.9268, -2.7330, -0.5642, -0.2740,  0.1398,  0.5086]])
Word index: [5], Embedding: tensor([[-0.7645,  0.2408,  0.1664, -2.2318,  1.3892, -0.5023,  1.6797, -1.0240,
          1.6859, -1.2177,  0.7650,  1.1971, -0.7128, -0.0656,  2.2050,  1.7852]])
Word index: [3], Embedding: tensor([[ 0.4990,  0.8780,  0.3894,  1.4625,  0.4795, -0.5334, -0.0347,  0.6573,
         -0.3112, -0.5620, -0.4835, -1.2721, -0.1740,  0.5541, -0.1817, -0.2345]])

(4, 16)


### 7. Target token embeddings + positional embeddings 

In [135]:
pe = PositionalEncoding_np(max_len=max_seq_len, d_model=d_model)

# Add the positional encoding to every target-token embedding (broadcast over rows).
pe_tgt_embeds = tgt_embedding + pe.forward(tgt_data)

pe_tgt_embeds.shape, pe_tgt_embeds

((4, 16),
 array([[ 0.64423001,  4.93000388, -0.12442428,  1.29534167,  0.38265419,
          0.45027864, -0.99403578,  2.34593689,  1.94566822, -0.29036391,
         -2.3494761 , -1.06886196,  0.90942109,  0.30537993,  1.95945716,
         -0.10382783],
        [-0.87330669,  1.00426142, -1.25788677, -0.08446777,  0.7529794 ,
          1.32364774, -0.27501002,  2.30561185,  0.21175182,  1.27196231,
         -0.92684317, -1.7329998 , -0.5641737 ,  0.72600037,  0.13978058,
          1.50856197],
        [-0.76447284,  1.24084058,  0.16642573, -1.23181415,  1.38921094,
          0.49766743,  1.67969298, -0.02395296,  1.68592429, -0.21769202,
          0.76496333,  2.19711864, -0.71278685,  0.93442459,  2.20497036,
          2.78517103],
        [ 0.49895304,  1.87799746,  0.38944435,  2.4625175 ,  0.47950602,
          0.46660012, -0.03465135,  1.65729696, -0.31122431,  0.43799645,
         -0.48349261, -0.27211261, -0.17401844,  1.55411685, -0.18165524,
          0.76552661]]))

In [136]:
# Decoder input: target embeddings plus positional encoding.
x_dec = pe_tgt_embeds

### 8. Self Attention in Decoder **(with mask)**

In [137]:
# Causal ("no-peek") mask: True above the diagonal marks future positions that
# a query may not attend to. It is sized from the decoder input instead of a
# hard-coded 4.
seq_length = x_dec.shape[0]
tgt_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype(bool)

tgt_mask

array([[False,  True,  True,  True],
       [False, False,  True,  True],
       [False, False, False,  True],
       [False, False, False, False]])

In [138]:
Q_dec, K_dec, V_dec = get_QKV_matrices(x_dec, 0, state_dict, num_heads, is_decoder=True)

Q_shape =  (4, 16)
K_shape =  (4, 16)
V_shape =  (4, 16)

After reshaping... 

Q_shape =  (4, 4, 4)
K_shape =  (4, 4, 4)
V_shape =  (4, 4, 4)


In [139]:
# Inspect the per-head keys of the decoder self-attention.
K_dec

array([[[ 7.83983860e-01, -2.89777000e+00, -1.99690011e+00,
         -8.15224661e-02],
        [ 1.19432158e-01, -1.32870414e+00,  6.87518232e-02,
          7.77369267e-01],
        [ 2.24591093e+00,  1.18072686e-01, -8.42719821e-01,
         -5.20880030e-01],
        [-2.40537342e-01, -1.17889995e+00, -3.21789001e-01,
          3.00278529e-01]],

       [[-2.24577325e-01, -1.41271065e-01, -1.54248077e-02,
         -4.35614408e-01],
        [ 1.03550371e-01, -1.66130565e+00,  5.08676441e-01,
         -8.00304722e-01],
        [-8.19911916e-01,  9.25537349e-01,  3.23517940e-02,
          7.18360881e-01],
        [-4.52862487e-01, -1.71429861e-01,  7.39203197e-01,
         -2.21240286e-01]],

       [[ 4.39262634e-01,  1.05416450e+00,  2.56656161e-03,
         -2.55351944e+00],
        [-2.82534251e-01,  1.19934027e+00, -4.78174562e-01,
         -6.89177429e-01],
        [ 5.96152102e-01, -2.51479706e-01,  2.78669174e-01,
         -1.73430184e+00],
        [ 8.34376828e-01,  4.87032609e-

### Preparing the mask for decoder attention mechanism

In [140]:
# Build an additive float mask from tgt_mask. A boolean tgt_mask follows the
# torch convention (True = blocked): it is converted to -inf/0 and tgt_mask is
# rebound to that float form. A non-boolean tgt_mask is added as-is.
attn_mask = np.zeros(tgt_mask.shape)

if tgt_mask is not None:
    if tgt_mask.dtype == 'bool':
        tgt_mask = np.where(tgt_mask, -np.inf, 0.0)
    attn_mask += tgt_mask


In [141]:
# 0 where attention is allowed, -inf where it is blocked.
attn_mask

array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]])

In [142]:
self_attn_output_dec = calculate_attention(Q_dec, K_dec, V_dec, attn_mask)  # masked decoder self-attention

In [143]:
# Per-head masked self-attention outputs of the decoder and their shape.
self_attn_output_dec, self_attn_output_dec.shape



(array([[[-1.02595533e+00,  9.66156934e-01, -2.29126390e+00,
           1.80570103e+00],
         [-4.94813720e-01,  3.63255462e-02, -1.18165437e+00,
           1.23624901e+00],
         [ 3.36763840e-01, -5.12013597e-01,  8.94194164e-02,
           5.29699393e-01],
         [ 4.78021100e-01, -4.68884973e-02, -2.54583645e-01,
           4.71600683e-01]],
 
        [[-2.39826266e-02, -1.33530231e+00, -1.67510994e+00,
          -9.04894019e-01],
         [-8.35024099e-02, -1.15296021e+00, -1.09796134e+00,
          -5.63298511e-01],
         [ 1.44866035e-01, -1.71497100e-01,  2.57711368e-01,
          -4.27455299e-01],
         [-5.85121465e-02, -7.90012909e-01, -3.01850574e-01,
          -4.11758885e-01]],
 
        [[ 1.53036519e+00, -1.85140923e+00, -1.51471969e+00,
           9.21301524e-01],
         [ 1.37871482e+00, -1.49100273e+00, -1.79188210e+00,
           7.92854103e-01],
         [ 1.35127614e+00, -8.67529277e-01, -1.43847057e+00,
          -6.94138377e-01],
         [ 1.29

In [144]:
# Use the decoder's own sequence length (axis 1 of the per-head output); `tgt_len` was taken from the encoder input.
attn_sa_dec = reshape_attention_op(self_attn_output_dec, num_heads, self_attn_output_dec.shape[1], head_dim)

In [145]:
# Decoder heads concatenated per position.
attn_sa_dec.shape, attn_sa_dec

((4, 16),
 array([[-1.02595533e+00,  9.66156934e-01, -2.29126390e+00,
          1.80570103e+00, -2.39826266e-02, -1.33530231e+00,
         -1.67510994e+00, -9.04894019e-01,  1.53036519e+00,
         -1.85140923e+00, -1.51471969e+00,  9.21301524e-01,
          5.28027792e-01,  7.11727409e-01, -2.57351948e-01,
          1.51332427e+00],
        [-4.94813720e-01,  3.63255462e-02, -1.18165437e+00,
          1.23624901e+00, -8.35024099e-02, -1.15296021e+00,
         -1.09796134e+00, -5.63298511e-01,  1.37871482e+00,
         -1.49100273e+00, -1.79188210e+00,  7.92854103e-01,
          4.19327992e-01,  7.33425132e-01,  5.74104863e-02,
          1.38451382e-01],
        [ 3.36763840e-01, -5.12013597e-01,  8.94194164e-02,
          5.29699393e-01,  1.44866035e-01, -1.71497100e-01,
          2.57711368e-01, -4.27455299e-01,  1.35127614e+00,
         -8.67529277e-01, -1.43847057e+00, -6.94138377e-01,
          1.55118181e-01,  2.82364028e-03,  1.78452212e-01,
          6.81186860e-02],
        [

In [None]:
/Users/sreevaatsav/Downloads/KGs/MHA_book.ipynb

### 9. Post self attention in the decoder self attention block

In [146]:
def decoder_layer_forward1(x_dec, final_attn_sa_op, state_dict, layer_num):
    """Finish the decoder's self-attention sub-layer.

    Runs the concatenated self-attention heads through the layer's attention
    output projection, adds the residual input ``x_dec`` and applies the
    decoder layer's first LayerNorm (``norm1``).

    Args:
        x_dec: Input to the decoder layer; used as the residual branch.
        final_attn_sa_op: Concatenated (merged-head) self-attention output.
        state_dict: Model parameters, forwarded to the helper functions.
        layer_num: Index of the decoder layer.

    Returns:
        The normalized output of the self-attention sub-layer.
    """
    projected = attention_output_projection(
        final_attn_sa_op, state_dict, layer_num, is_decoder=True
    )
    # Residual connection, accumulated in place into the projection result.
    projected += x_dec
    return layer_norm(projected, state_dict, layer_num, "norm1", is_decoder=True)


In [147]:
# Output projection + residual + norm1 for decoder layer 0.
output_dec_1 = decoder_layer_forward1(
    x_dec, attn_sa_dec, state_dict, layer_num=0
)

In [148]:
# Show the self-attention sub-layer output and its shape.
(output_dec_1, output_dec_1.shape)

(array([[ 0.79080436,  2.61675492,  0.19728347,  0.11469253, -0.84743373,
         -0.22291594,  0.25940573,  0.67274157, -0.13527261, -0.38400555,
         -1.56532808, -1.85249906,  0.65941381, -0.47682196,  0.56739699,
         -0.39421645],
        [-0.29441922,  1.09712817, -0.40351996,  0.22616283, -0.75575188,
          0.25824741,  0.46990197,  1.71531652, -0.62644138,  0.81040219,
         -1.21026324, -2.66499552,  0.36661349,  0.43849546, -0.25539105,
          0.8285142 ],
        [-2.00405224,  0.13751043,  0.09835611, -1.51830931, -0.80674191,
         -0.84878997,  0.94845981,  0.49026421,  1.10373596, -0.70779382,
         -0.45183934,  0.9925127 , -0.58471128,  1.09218878,  0.49342051,
          1.56578937],
        [ 0.31189761,  1.17028354,  0.41392054,  2.03995892, -1.06850928,
         -0.54656387, -0.17191765,  1.32368864, -0.99201868, -0.01623896,
         -0.9910646 , -1.69925982,  0.12088679,  1.09146899, -0.95770799,
         -0.02882418]]),
 (4, 16))

### 10. **Cross-attention** in the decoder

In [149]:
# The encoder output serves as the "memory" attended to by cross-attention.
memory = x_enc

In [150]:
# Cross-attention: queries come from the decoder stream,
# keys and values both come from the encoder memory.
query_dec_ca = output_dec_1
key_dec_ca = value_dec_ca = memory

In [151]:
def get_QKV_crossattn(query_tensor, key_tensor, state_dict, layer_num, num_heads):
    """
    Compute the per-head Q, K and V matrices for decoder cross-attention.

    Mirrors the packed in-projection of ``nn.MultiheadAttention``: the layer's
    ``in_proj_weight`` / ``in_proj_bias`` stack the query, key and value
    projections row-wise as ``[W_q; W_k; W_v]`` / ``[b_q; b_k; b_v]``.
    Queries are projected from the decoder stream, while keys AND values are
    both projected from ``key_tensor`` (the encoder memory).

    Args:
    - query_tensor: 2-D array of shape (tgt_len, embed_dim) from the decoder.
    - key_tensor: 2-D array of shape (src_len, embed_dim); used for both the
      keys and the values.
    - state_dict: State dictionary containing the parameters of the model; the
      in-projection entries must support ``.numpy()`` (e.g. CPU torch tensors).
    - layer_num: Index of the decoder layer.
    - num_heads: Number of attention heads; must divide embed_dim.

    Returns:
    - Q: Query array of shape (num_heads, tgt_len, head_dim).
    - K: Key array of shape (num_heads, src_len, head_dim).
    - V: Value array of shape (num_heads, src_len, head_dim).

    Raises:
    - ValueError: if embed_dim is not divisible by num_heads.
    """
    tgt_len, embed_dim = query_tensor.shape
    if embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )
    head_dim = embed_dim // num_heads

    # Packed in-projection parameters of this decoder layer's cross-attention.
    prefix = "transformer.decoder.layers.{}.multihead_attn.".format(layer_num)
    W_dec_ca = state_dict[prefix + "in_proj_weight"].numpy()
    b_dec_ca = state_dict[prefix + "in_proj_bias"].numpy()

    # Split into the query part and the combined key/value part.
    W_q, W_kv = W_dec_ca[:embed_dim], W_dec_ca[embed_dim:]
    b_q, b_kv = b_dec_ca[:embed_dim], b_dec_ca[embed_dim:]

    # The in-projection bias must be applied; omitting it only matches PyTorch
    # while the biases are still at their zero initialization.
    Q = np.matmul(query_tensor, W_q.T) + b_q
    KV = np.matmul(key_tensor, W_kv.T) + b_kv
    K, V = KV[:, :embed_dim], KV[:, embed_dim:2 * embed_dim]

    def _split_heads(x):
        # (seq_len, embed_dim) -> (num_heads, seq_len, head_dim)
        return np.transpose(
            np.reshape(x, (x.shape[0], num_heads, head_dim)), (1, 0, 2)
        )

    return _split_heads(Q), _split_heads(K), _split_heads(V)

In [152]:
layer_num = 0  # trace decoder layer 0

# Per-head query/key/value projections for cross-attention.
Q_dec1, K_dec1, V_dec1 = get_QKV_crossattn(
    query_dec_ca, key_dec_ca, state_dict, layer_num, num_heads
)


In [153]:
# Keys with their last two axes swapped (per head), for inspection;
# not used by the other cells shown in this section.
K_dec_ca1_T = K_dec1.transpose(0, 2, 1)

In [154]:
# Attention of decoder queries over the encoder memory, computed per head.
cross_attn_output_dec = calculate_attention(
    Q_dec1, K_dec1, V_dec1
)

In [155]:
# Display the raw cross-attention output: one (tgt_len, head_dim) slice per head.
cross_attn_output_dec

array([[[-0.18553441, -0.57655386, -0.69666788,  0.25312486],
        [-0.23444545, -0.61369421, -0.66903157,  0.30315544],
        [-0.48267096, -0.72678632, -0.47050186,  0.51914174],
        [-0.24223902, -0.64934472, -0.69094346,  0.31772663]],

       [[-0.4313023 ,  0.17086466, -0.28128801, -0.18745656],
        [-0.39735754,  0.08389441, -0.22194945, -0.27801537],
        [-0.32980284,  0.06160475, -0.13966224, -0.39402187],
        [-0.40939   ,  0.01804331, -0.23653818, -0.28511785]],

       [[ 0.79633162, -0.68718596, -0.10476245,  0.01925739],
        [ 0.79108095, -0.68322047, -0.15286889,  0.04382875],
        [ 0.79858936, -0.68131396, -0.14482916,  0.03213887],
        [ 0.78334287, -0.6882062 , -0.17583734,  0.05549425]],

       [[ 0.17863263, -0.71823765, -0.38514964, -0.2301647 ],
        [ 0.23649462, -0.8250898 , -0.31696894, -0.29983348],
        [ 0.19556961, -0.76403825, -0.31572912, -0.2474797 ],
        [ 0.23015667, -0.83680755, -0.29030636, -0.29460031]]])

In [156]:
# Merge the per-head cross-attention outputs back into a single matrix
# (heads are laid side by side along the feature axis).
attn_ca_dec = reshape_attention_op(
    cross_attn_output_dec, num_heads, tgt_len, head_dim
)

In [157]:
# Show the merged cross-attention output and its shape.
(attn_ca_dec, attn_ca_dec.shape)

(array([[-0.18553441, -0.57655386, -0.69666788,  0.25312486, -0.4313023 ,
          0.17086466, -0.28128801, -0.18745656,  0.79633162, -0.68718596,
         -0.10476245,  0.01925739,  0.17863263, -0.71823765, -0.38514964,
         -0.2301647 ],
        [-0.23444545, -0.61369421, -0.66903157,  0.30315544, -0.39735754,
          0.08389441, -0.22194945, -0.27801537,  0.79108095, -0.68322047,
         -0.15286889,  0.04382875,  0.23649462, -0.8250898 , -0.31696894,
         -0.29983348],
        [-0.48267096, -0.72678632, -0.47050186,  0.51914174, -0.32980284,
          0.06160475, -0.13966224, -0.39402187,  0.79858936, -0.68131396,
         -0.14482916,  0.03213887,  0.19556961, -0.76403825, -0.31572912,
         -0.2474797 ],
        [-0.24223902, -0.64934472, -0.69094346,  0.31772663, -0.40939   ,
          0.01804331, -0.23653818, -0.28511785,  0.78334287, -0.6882062 ,
         -0.17583734,  0.05549425,  0.23015667, -0.83680755, -0.29030636,
         -0.29460031]]),
 (4, 16))

### 11. After cross-attention in the decoder block

In [168]:
def decoder_layer_forward2(self_attn_dec, final_attn_ca_op, state_dict, layer_num):
    """Finish the decoder's cross-attention and feed-forward sub-layers.

    1. Cross-attention sub-layer: attention output projection, residual with
       ``self_attn_dec``, then LayerNorm ``norm2``.
    2. Feed-forward sub-layer: ``linear_relu_linear``, residual with the
       ``norm2`` output, then LayerNorm ``norm3``.

    Args:
        self_attn_dec: Output of the self-attention sub-layer; residual branch.
        final_attn_ca_op: Concatenated (merged-head) cross-attention output.
        state_dict: Model parameters, forwarded to the helper functions.
        layer_num: Index of the decoder layer.

    Returns:
        The output of the whole decoder layer (after ``norm3``).
    """
    # --- cross-attention sub-layer: projection, residual, LayerNorm 2 ---
    cross_proj = attention_output_projection(
        final_attn_ca_op, state_dict, layer_num, is_decoder=True, is_crossattn=True
    )
    cross_proj += self_attn_dec
    after_cross = layer_norm(cross_proj, state_dict, layer_num, "norm2", is_decoder=True)

    # --- feed-forward sub-layer: FFN, residual, LayerNorm 3 ---
    ffn_out = linear_relu_linear(after_cross, state_dict, layer_num, is_decoder=True)
    return layer_norm(after_cross + ffn_out, state_dict, layer_num, "norm3", is_decoder=True)


In [169]:
# Cross-attention projection + norm2, then feed-forward + norm3.
linear_op_dec_3 = decoder_layer_forward2(
    output_dec_1, attn_ca_dec, state_dict, layer_num
)

In [170]:
# Display the decoder layer's output (after norm3).
linear_op_dec_3

array([[-0.5356827 ,  1.93180892, -0.51183151, -0.06886474, -0.35834298,
         0.006337  ,  0.74473562,  0.38577164,  0.36656192,  0.24361149,
        -1.51696656, -1.19693871,  1.24391757, -1.00695612,  1.61225937,
        -1.33942023],
       [-0.95041354,  0.78181972, -0.90229085, -0.16324231, -0.3592311 ,
         0.16954442,  0.85607012,  1.72101943, -0.56157297,  1.35115966,
        -1.33890043, -2.14980348,  0.58189228,  0.48414693,  0.76356128,
        -0.28375916],
       [-1.7267202 ,  0.12770976, -0.20487979, -1.94317655,  0.58656848,
        -1.09442334,  0.84953304,  0.56638679,  0.32088158, -0.55749931,
        -1.0709999 ,  0.7707259 , -0.08021092,  1.04794369,  1.6733852 ,
         0.73477555],
       [-0.18048601,  1.14506299, -0.14022405,  1.80746055, -0.66768442,
        -0.59171444,  0.12012847,  1.24848012, -1.23003434,  0.65160384,
        -1.4823933 , -1.11823431,  0.06295639,  1.53559277, -0.04201097,
        -1.11850329]])

### 12. Final fully connected (`fc`) output layer

In [161]:
# Output of the last traced decoder layer feeds the final linear layer.
dec_output_final = linear_op_dec_3

In [162]:
# Apply the model's final fully connected layer (`fc`): x @ W^T + b.
W_ff = state_dict["fc.weight"].numpy()
b_ff = state_dict["fc.bias"].numpy()
final_op = np.matmul(dec_output_final, W_ff.T) + b_ff

In [163]:
# Show the hand-computed model output and its shape.
(final_op, final_op.shape)

(array([[ 0.37746685,  0.07337387, -0.80958893,  0.14691591, -0.49054451,
          0.37394687, -0.27048344, -0.92419295,  0.62737914, -0.61220299,
          0.25932976, -0.16877907,  0.13227677,  1.30987519,  0.76432395,
         -0.61062077, -0.13705934,  0.7552388 , -0.54154864,  0.74224788],
        [ 0.43438884,  0.35607952, -0.49823122, -0.43256203, -0.23772529,
          0.87583749, -1.06413742, -0.34228145,  0.37721902, -0.44517008,
          0.65188803, -0.26663242, -0.06459797,  0.7806719 ,  0.52723765,
         -0.65729075,  0.09763294,  1.08035112, -0.91532704,  1.17432552],
        [-0.58731561,  0.21614686, -0.60998807, -0.07736154, -0.59575491,
          0.68770998, -0.26769391, -0.44631606,  0.51722351,  1.11366001,
          0.1039768 , -0.30067784, -1.18722492, -0.25477752,  0.63879546,
          0.09141405,  0.02895858, -0.05193275,  0.38341005, -0.10680178],
        [ 0.78883248,  0.76411429, -0.12451219, -0.02487891,  0.18648671,
         -0.01425053, -0.97255592, 

### Reference output from the PyTorch model (built on `nn.Transformer`):

In [164]:
# Reference output of the PyTorch model for comparison with final_op above.
# NOTE: it has an extra size-1 axis ((4, 1, 20) vs (4, 20) in this run).
print(output, output.shape)

tensor([[[ 0.3775,  0.0734, -0.8096,  0.1469, -0.4905,  0.3739, -0.2705,
          -0.9242,  0.6274, -0.6122,  0.2593, -0.1688,  0.1323,  1.3099,
           0.7643, -0.6106, -0.1371,  0.7552, -0.5415,  0.7423]],

        [[ 0.4344,  0.3561, -0.4982, -0.4326, -0.2377,  0.8758, -1.0641,
          -0.3423,  0.3772, -0.4452,  0.6519, -0.2666, -0.0646,  0.7807,
           0.5272, -0.6573,  0.0976,  1.0804, -0.9153,  1.1743]],

        [[-0.5873,  0.2161, -0.6100, -0.0774, -0.5958,  0.6877, -0.2677,
          -0.4463,  0.5172,  1.1137,  0.1040, -0.3007, -1.1872, -0.2548,
           0.6388,  0.0914,  0.0290, -0.0519,  0.3834, -0.1068]],

        [[ 0.7888,  0.7641, -0.1245, -0.0249,  0.1865, -0.0143, -0.9726,
          -0.9177,  0.3889, -0.3268,  0.1702, -0.8667,  0.2136,  0.3825,
           0.2635, -0.9925, -0.4424,  1.0395, -1.2948,  0.7781]]],
       grad_fn=<ViewBackward0>) torch.Size([4, 1, 20])
