In [2]:
import torch.nn as nn
import torch
import math

In [3]:
# Configuration for a single-head, d_model=2 toy example.
E = 2        # embedding size (d_model)
n_head = 1   # number of attention heads
SEED = 0     # seeds the random weight initialisation of the layers built below
torch.manual_seed(SEED)

# One target sequence, shape (batch=1, seq_len=3, E=2), batch_first layout.
query = torch.tensor([[[0.7204, 0.0731],
          [0.9699, 0.1078],
          [0.8829, 0.4132]]])

# Independent copies of `query` (self-attention uses the same tensor for Q, K and V).
key = query.clone()
value = query.clone()

### Getting the memory (keys and values for cross-attention) from the encoder layer


In [4]:
# Build an encoder layer and run `query` through it to obtain the memory that the
# decoder's cross-attention attends to. dropout=0 keeps the forward pass deterministic.
encoder_layer = nn.TransformerEncoderLayer(d_model=E, nhead=n_head, batch_first=True, dropout=0, dim_feedforward=2)
# Zero the feed-forward biases in place. Each bias keeps its own shape
# (dim_feedforward for linear1, d_model for linear2) rather than both layers
# sharing one hard-coded size-2 Parameter.
nn.init.zeros_(encoder_layer.linear1.bias)
nn.init.zeros_(encoder_layer.linear2.bias)
encoder_output = encoder_layer(query)

In [5]:
# Encoder output, used below as the decoder's memory. With d_model=2, LayerNorm
# maps each row to approximately (+1, -1) or (-1, +1).
encoder_output

tensor([[[ 1.0000, -1.0000],
         [ 1.0000, -1.0000],
         [ 1.0000, -1.0000]]], grad_fn=<NativeLayerNormBackward0>)

### Defining the Decoder

In [6]:
# Build the decoder layer whose forward pass the following cells reproduce by hand.
# dropout=0 makes it deterministic. It is called without a tgt_mask, so its
# self-attention is NOT causally masked.
decoder_layer = nn.TransformerDecoderLayer(d_model=E, nhead=n_head, batch_first=True, dropout=0, dim_feedforward=2)
# Zero the feed-forward biases in place, each with its own correct shape
# (dim_feedforward for linear1, d_model for linear2).
nn.init.zeros_(decoder_layer.linear1.bias)
nn.init.zeros_(decoder_layer.linear2.bias)
# tgt = query, memory = encoder_output
decoder_output = decoder_layer(query, encoder_output)

In [7]:
# Self-attention input projection, shape (3E, E): Q rows first, then K, then V.
decoder_layer.self_attn.in_proj_weight.detach()

tensor([[-0.0700,  0.2625],
        [ 0.6664,  0.3238],
        [ 0.5044, -0.6509],
        [-0.6687, -0.6628],
        [-0.5943, -0.6812],
        [ 0.4039, -0.2455]])

### MHA one: decoder self-attention, residual connection and normalization

In [8]:
##### SA
# Decoder sub-layer 1: single-head self-attention followed by the post-norm
# residual, x = norm1(tgt + SA(tgt)). decoder_layer was called without a
# tgt_mask, so no causal mask is applied here either.
sa = decoder_layer.self_attn
tgt = query[0]  # (seq_len, E); this notebook uses a batch of 1

# Joint input projection (weight AND bias), then split into Q, K, V.
q, k, v = nn.functional.linear(tgt, sa.in_proj_weight, sa.in_proj_bias).chunk(3, dim=-1)
# Scaled dot-product attention; with a single head, head_dim == E.
attn = ((q / math.sqrt(E)) @ k.T).softmax(dim=-1)
sa_out = nn.functional.linear(attn @ v, sa.out_proj.weight, sa.out_proj.bias)

# The residual connection is on the decoder input (tgt = query), not on `key`.
x = tgt + sa_out

###### NORM
copied_data = x.clone()  # pre-LayerNorm residual sum, kept for inspection
norm1 = decoder_layer.norm1
x = nn.functional.layer_norm(x, (E,), norm1.weight, norm1.bias, norm1.eps)

# x is going to be the query

In [9]:
# Transposed cross-attention input projection, shape (E, 3E).
decoder_layer.multihead_attn.in_proj_weight.detach().t()

tensor([[-0.0170, -0.4485, -0.3418, -0.8304,  0.6771, -0.0114],
        [ 0.7792, -0.5596,  0.4939,  0.7070, -0.4128,  0.2273]])

In [10]:
 decoder_layer.state_dict()['multihead_attn.in_proj_weight'].t().chunk(3, dim=-1)

(tensor([[-0.0170, -0.4485],
         [ 0.7792, -0.5596]]),
 tensor([[-0.3418, -0.8304],
         [ 0.4939,  0.7070]]),
 tensor([[ 0.6771, -0.0114],
         [-0.4128,  0.2273]]))

### MHA two: cross-attention over the encoder memory, residual connection and normalization

In [11]:
##### MHA
# Decoder sub-layer 2: single-head cross-attention followed by the post-norm
# residual, x = norm2(x + MHA(x, memory)).
# Queries come from `x` (output of the self-attention cell); keys and values
# come from the encoder output. NOTE: this cell reads and overwrites `x`, so
# re-running it on its own is not idempotent.
ca = decoder_layer.multihead_attn
my_memory = encoder_output.clone()
mem = my_memory[0]  # (seq_len, E); this notebook uses a batch of 1

# in_proj_weight / in_proj_bias stack the Q, K, V projections along dim 0.
w_q, w_k, w_v = ca.in_proj_weight.chunk(3)
b_q, b_k, b_v = ca.in_proj_bias.chunk(3)
q = nn.functional.linear(x, w_q, b_q)
k = nn.functional.linear(mem, w_k, b_k)
v = nn.functional.linear(mem, w_v, b_v)

# Scaled dot-product attention; with a single head, head_dim == E.
my_attn_output_weights = ((q / math.sqrt(E)) @ k.T).softmax(dim=-1)
my_output = nn.functional.linear(my_attn_output_weights @ v, ca.out_proj.weight, ca.out_proj.bias)
x = x + my_output

#### NORM
copied_data = x.clone()  # pre-LayerNorm residual sum, kept for inspection
norm2 = decoder_layer.norm2  # the decoder's second LayerNorm (after cross-attention)
x = nn.functional.layer_norm(x, (E,), norm2.weight, norm2.bias, norm2.eps)


### Feed-forward block (two fully connected layers), residual connection and normalization

In [12]:
#### FC
# Decoder sub-layer 3: position-wise feed-forward followed by the post-norm
# residual, x = norm3(x + linear2(relu(linear1(x)))). dropout=0, so no dropout.
# `x` is already (seq_len, E); squeeze() is not needed and would wrongly
# collapse seq_len if it were 1. NOTE: reads and overwrites `x`.
ff1 = decoder_layer.linear1
ff2 = decoder_layer.linear2
hidden = nn.functional.relu(nn.functional.linear(x, ff1.weight, ff1.bias))
ff_out = nn.functional.linear(hidden, ff2.weight, ff2.bias)
x = x + ff_out

norm3 = decoder_layer.norm3  # the decoder's third LayerNorm (after the feed-forward)
x = nn.functional.layer_norm(x, (E,), norm3.weight, norm3.bias, norm3.eps)

In [13]:
# Reference result from nn.TransformerDecoderLayer, shape (1, seq_len, E);
# the manual cells above recompute it step by step.
decoder_output

tensor([[[ 1.0000, -1.0000],
         [ 1.0000, -1.0000],
         [ 1.0000, -1.0000]]], grad_fn=<NativeLayerNormBackward0>)

In [14]:
# Manually computed decoder output. Check it against the module's result
# (decoder_output has a leading batch dimension of 1; x does not).
# NOTE(review): with d_model=2, LayerNorm maps every row to approximately
# (+1, -1) or (-1, +1), so this only checks signs. Comparing pre-norm values
# would be a stronger test.
torch.testing.assert_close(x, decoder_output[0])
x  # out result

tensor([[ 1.0000, -1.0000],
        [ 1.0000, -1.0000],
        [ 1.0000, -1.0000]], grad_fn=<NativeLayerNormBackward0>)