<a href="https://colab.research.google.com/github/NeoLin1103/Machine-learning-algorithms/blob/main/Self_attention_layer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
'''
I: (Embedding dimension, Sequence length) -- the code's `Input` tensor is I^T: (Sequence length, Embedding dimension)
A: (Sequence length, Sequence length)
O: (Output embedding dimension, Sequence length)
W: (Output embedding dimension, Embedding dimension)
 
I^T * Wq^T = Q^T
I^T * Wk^T = K^T
K^T * (Q^T)^T = A
softmax(A, dim=0) = A'
I^T * Wv^T = V^T
(V^T)^T * A' = O
 
Q^T * (K^T)^T = A^T
softmax(A^T, dim=-1) = A'^T
A'^T * V^T = O^T
'''

In [2]:
# Toy problem size: 15 tokens, each represented by a 42-dimensional embedding.
seq_length, embed_dim = 15, 42
Input = torch.rand(seq_length, embed_dim)  # one token per row: (seq_length, embed_dim)

# Three independent learned projections of the same input (query, key, value).
W_query = nn.Linear(embed_dim, embed_dim)
W_key = nn.Linear(embed_dim, embed_dim)
W_value = nn.Linear(embed_dim, embed_dim)

Q, K, V = W_query(Input), W_key(Input), W_value(Input)

# A[i, j] is the dot product of key i with query j, scaled by sqrt(embed_dim).
A = torch.matmul(K, Q.t()) / embed_dim ** 0.5
# Normalise along dim=0 (the key axis), so every column of A_trans sums to 1.
A_trans = torch.softmax(A, dim=0)

# Column j of O is the attention-weighted mix of value vectors for token j,
# which gives O the shape (embed_dim, seq_length).
O = torch.matmul(V.t(), A_trans)
print(O.shape)
print(O)

torch.Size([42, 15])
tensor([[ 0.4532,  0.4561,  0.4547,  0.4507,  0.4524,  0.4547,  0.4514,  0.4564,
          0.4519,  0.4509,  0.4518,  0.4541,  0.4503,  0.4541,  0.4565],
        [-0.3864, -0.3809, -0.3831, -0.3872, -0.3884, -0.3841, -0.3869, -0.3791,
         -0.3846, -0.3888, -0.3874, -0.3841, -0.3887, -0.3855, -0.3857],
        [ 0.0479,  0.0470,  0.0469,  0.0501,  0.0465,  0.0462,  0.0479,  0.0456,
          0.0475,  0.0479,  0.0490,  0.0447,  0.0473,  0.0482,  0.0473],
        [-0.1518, -0.1495, -0.1508, -0.1506, -0.1522, -0.1528, -0.1522, -0.1505,
         -0.1524, -0.1513, -0.1497, -0.1524, -0.1538, -0.1519, -0.1484],
        [ 0.1579,  0.1580,  0.1591,  0.1570,  0.1571,  0.1605,  0.1592,  0.1584,
          0.1595,  0.1569,  0.1560,  0.1587,  0.1595,  0.1597,  0.1560],
        [-0.0271, -0.0255, -0.0250, -0.0279, -0.0263, -0.0227, -0.0257, -0.0245,
         -0.0249, -0.0279, -0.0280, -0.0242, -0.0250, -0.0260, -0.0288],
        [ 0.2860,  0.2847,  0.2854,  0.2856,  0.2858,  

In [3]:
class Attention(nn.Module):
  """Single-head scaled dot-product self-attention with column-layout output.

  The input holds one token per row; the output holds one token per column:

      Q, K, V = W_query(Input), W_key(Input), W_value(Input)
      A       = K @ Q^T / sqrt(embed_dim_out)   # A[i, j] = <key_i, query_j> / sqrt(d)
      A'      = softmax(A) over the key axis    # every column sums to 1
      O       = V^T @ A'

  Args:
    embed_dim_in: size of each input token embedding.
    embed_dim_out: size of the query/key/value projections and of each
      output token embedding.
  """

  def __init__(self,embed_dim_in,embed_dim_out):
    super().__init__()
    # Creation order (query, key, value) is kept so that seeded parameter
    # initialisation yields the same weights as before.
    self.W_query = nn.Linear(embed_dim_in,embed_dim_out)
    self.W_key = nn.Linear(embed_dim_in,embed_dim_out)
    self.W_value = nn.Linear(embed_dim_in,embed_dim_out)
    # dim=-2 is the key axis of A. For an unbatched 2-D input it is the same
    # axis as dim=0, and it stays correct when leading batch dims are present.
    self.Softmax = nn.Softmax(dim=-2)

  def forward(self,Input):
    """Apply self-attention to `Input`.

    Args:
      Input: tensor of shape (seq_length, embed_dim_in), optionally with
        leading batch dimensions (..., seq_length, embed_dim_in).

    Returns:
      Tensor of shape (..., embed_dim_out, seq_length); column j is the
      attention-weighted combination of value vectors for token j.
    """
    Q = self.W_query(Input)
    K = self.W_key(Input)
    V = self.W_value(Input)

    # Scale the logits by sqrt(d_k), taken from the projected key width so no
    # global variable is needed. This matches the standalone demo cell, which
    # divides by embed_dim**0.5.
    scale = K.shape[-1] ** 0.5
    A = (K @ Q.transpose(-2, -1)) / scale
    A_prime = self.Softmax(A)

    O = V.transpose(-2, -1) @ A_prime
    return O

In [5]:
# Build a self-attention layer whose output width equals its input width (42).
layer = Attention(
    embed_dim_in=42,
    embed_dim_out=42,
)

In [6]:
# Smoke-test the layer on one random sequence.
seq_length = 15
embed_dim = 42  # must equal the embed_dim_in that `layer` was constructed with
Input = torch.rand(seq_length, embed_dim)  # one token per row: (seq_length, embed_dim)
# Call the module itself rather than .forward(): nn.Module.__call__ dispatches
# to forward() and also runs any registered hooks.
out = layer(Input)
print(out.shape)
print(out)

torch.Size([42, 15])
tensor([[ 3.5335e-01,  3.5659e-01,  3.5526e-01,  3.5760e-01,  3.7091e-01,
          3.6311e-01,  3.6275e-01,  3.6405e-01,  3.5912e-01,  3.5042e-01,
          3.5503e-01,  3.5999e-01,  3.6633e-01,  3.6155e-01,  3.6000e-01],
        [ 4.0152e-01,  3.9871e-01,  3.9765e-01,  3.8952e-01,  3.9221e-01,
          3.9234e-01,  3.9850e-01,  3.9024e-01,  3.9429e-01,  3.9609e-01,
          3.8597e-01,  3.9263e-01,  3.9768e-01,  4.1125e-01,  3.9644e-01],
        [ 3.0265e-02,  2.2602e-02,  2.4049e-02,  1.2744e-02,  2.9362e-02,
          3.7751e-02,  2.4223e-02,  3.4803e-02,  2.1703e-02, -3.1042e-04,
          1.9086e-02,  9.3021e-03,  2.0091e-02,  2.5775e-02,  1.5490e-02],
        [-6.8744e-01, -6.9069e-01, -6.8718e-01, -6.8736e-01, -6.8936e-01,
         -6.8934e-01, -6.9296e-01, -6.8493e-01, -6.8861e-01, -6.8208e-01,
         -6.8327e-01, -6.8760e-01, -6.9109e-01, -6.9488e-01, -6.8686e-01],
        [ 1.9361e-01,  1.9071e-01,  1.9421e-01,  1.9502e-01,  1.8507e-01,
          1.7