In [1]:
from math import sqrt

import torch
import torch.nn as nn

In [2]:
class Self_Attention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self, input_dim, dim_k, dim_v):
        super(Self_Attention, self).__init__()
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        Q = self.q(x)  # Q: batch_size * seq_len * dim_k
        K = self.k(x)  # K: batch_size * seq_len * dim_k
        V = self.v(x)  # V: batch_size * seq_len * dim_v

        atten = (
            nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1)))
            * self._norm_fact
        )  # Q * K.T() # batch_size * seq_len * seq_len

        output = torch.bmm(
            atten, V
        )  # Q * K.T() * V # batch_size * seq_len * dim_v

        return output


In [3]:
X = torch.randn(4,3,2)

In [4]:
X

tensor([[[-0.4395,  0.5981],
         [-0.4380, -0.5342],
         [ 1.3244,  0.2491]],

        [[ 0.5067,  0.0293],
         [ 1.2101,  0.5148],
         [-0.5870, -0.4357]],

        [[-1.1413,  1.5060],
         [ 0.4430,  1.0895],
         [ 2.2553,  0.8129]],

        [[-1.5835, -1.2745],
         [ 1.1963,  0.6706],
         [ 0.5570,  0.5289]]])

In [5]:
self_attention = Self_Attention(2,4,5)

In [6]:
res = self_attention(X)

In [7]:
res

tensor([[[-3.5565e-01,  1.0159e-01, -2.5908e-01, -1.7877e-01, -4.5910e-02],
         [-3.5945e-01,  8.2470e-02, -2.5623e-01, -2.1338e-01, -3.2045e-02],
         [-3.6035e-01,  4.7512e-02, -2.4232e-01, -2.9357e-01, -7.0854e-04]],

        [[-3.7625e-01,  3.7725e-02, -2.6096e-01, -2.7229e-01, -7.4061e-03],
         [-3.7055e-01,  3.3232e-02, -2.5079e-01, -2.9885e-01,  2.3704e-03],
         [-3.8739e-01,  4.6998e-02, -2.8102e-01, -2.1925e-01, -2.6959e-02]],

        [[ 1.1833e-01,  5.2782e-01,  2.3766e-01, -4.9412e-01,  2.9493e-02],
         [-8.7404e-03,  3.2137e-01,  1.4459e-01, -6.2764e-01,  9.4245e-02],
         [-1.6882e-01,  5.5810e-02,  2.9727e-02, -8.0881e-01,  1.8087e-01]],

        [[-5.2465e-01,  7.9635e-02, -4.9276e-01,  2.4125e-01, -1.9264e-01],
         [-4.1031e-01,  7.6446e-02, -3.2682e-01, -8.5592e-02, -7.6742e-02],
         [-4.5395e-01,  7.8101e-02, -3.9034e-01,  4.0177e-02, -1.2138e-01]]],
       grad_fn=<BmmBackward0>)

In [8]:
res.shape

torch.Size([4, 3, 5])

In [9]:
class Self_Attention_Muti_Head(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self,input_dim,dim_k,dim_v,nums_head):
        super(Self_Attention_Muti_Head,self).__init__()
        assert dim_k % nums_head == 0
        assert dim_v % nums_head == 0
        self.q = nn.Linear(input_dim,dim_k)
        self.k = nn.Linear(input_dim,dim_k)
        self.v = nn.Linear(input_dim,dim_v)
        
        self.nums_head = nums_head
        self.dim_k = dim_k
        self.dim_v = dim_v
        self._norm_fact = 1 / sqrt(dim_k)
        
    
    def forward(self,x):
        Q = self.q(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.nums_head) 
        K = self.k(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.nums_head) 
        V = self.v(x).reshape(-1,x.shape[0],x.shape[1],self.dim_v // self.nums_head)
        print(x.shape)
        print(Q.size())

        atten = nn.Softmax(dim=-1)(torch.matmul(Q,K.permute(0,1,3,2))) # Q * K.T() # batch_size * seq_len * seq_len
        
        output = torch.matmul(atten,V).reshape(x.shape[0],x.shape[1],-1) # Q * K.T() * V # batch_size * seq_len * dim_v
        return output    