In [41]:
from math import sqrt
import torch
from torch import nn


class Self_Attention(torch.nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self,input_dim,dim_k,dim_v):
        super(Self_Attention,self).__init__()
        self.q = torch.nn.Linear(input_dim,dim_k)
        self.k = torch.nn.Linear(input_dim,dim_k)
        self.v = torch.nn.Linear(input_dim,dim_v)
        self._norm_fact = 1 / sqrt(dim_k)
        
    
    def forward(self,x):
        Q = self.q(x) # Q: batch_size * seq_len * dim_k
        K = self.k(x) # K: batch_size * seq_len * dim_k
        V = self.v(x) # V: batch_size * seq_len * dim_v
        print('before: ', torch.nn.Softmax(dim=1)(torch.bmm(Q,K.permute(0,2,1))) )
        atten = torch.nn.Softmax(dim=1)(torch.bmm(Q,K.permute(0,2,1))) * self._norm_fact # Q * K.T() # batch_size * seq_len * seq_len
        print('atten: ', atten)
        output = torch.bmm(atten,V) # Q * K.T() * V # batch_size * seq_len * dim_v
        
        return output

In [42]:
x = torch.rand((4,3,2))
model = Self_Attention(2, 4, 5)
res = model(x)
res.shape

before:  tensor([[[0.3070, 0.3090, 0.2706],
         [0.3013, 0.3037, 0.2582],
         [0.3917, 0.3874, 0.4712]],

        [[0.3149, 0.2928, 0.2381],
         [0.3272, 0.3274, 0.3223],
         [0.3580, 0.3798, 0.4395]],

        [[0.3581, 0.3378, 0.3725],
         [0.2882, 0.3121, 0.2705],
         [0.3537, 0.3502, 0.3570]],

        [[0.3610, 0.3407, 0.3819],
         [0.2546, 0.2936, 0.2106],
         [0.3845, 0.3657, 0.4075]]], grad_fn=<SoftmaxBackward>)
atten:  tensor([[[0.1535, 0.1545, 0.1353],
         [0.1506, 0.1518, 0.1291],
         [0.1958, 0.1937, 0.2356]],

        [[0.1574, 0.1464, 0.1191],
         [0.1636, 0.1637, 0.1612],
         [0.1790, 0.1899, 0.2198]],

        [[0.1790, 0.1689, 0.1862],
         [0.1441, 0.1560, 0.1353],
         [0.1769, 0.1751, 0.1785]],

        [[0.1805, 0.1703, 0.1910],
         [0.1273, 0.1468, 0.1053],
         [0.1922, 0.1829, 0.2038]]], grad_fn=<MulBackward0>)


torch.Size([4, 3, 5])

In [19]:
x = torch.rand((1,2,2))
torch.nn.Softmax(dim=-1)(x)

tensor([[[0.5801, 0.4199],
         [0.3173, 0.6827]]])

In [40]:
import numpy as np
np.sum([[0.3323, 0.1323, 0.5354],
        [0.1261, 0.4239, 0.4500]])

2.0

In [39]:
m = nn.Softmax(dim=-1)
input = torch.randn(2, 3)
output = m(input)
output

tensor([[0.3323, 0.1323, 0.5354],
        [0.1261, 0.4239, 0.4500]])