In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F



In [15]:
class SelfAttention(nn.Module):
    def __init__(self,d_model):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        self.key = nn.Linear(d_model,d_model)
        self.query = nn.Linear(d_model,d_model)
        self.value = nn.Linear(d_model,d_model)
        self.softmax = nn.Softmax(dim=2) #along 3rd dimension, so the sum of similarities for a given query is 1

    def forward(self,x):
        #x.shape = [m batches, n words in a sequence, d_model features in a word (length of embedding)]
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        scores = torch.matmul(queries, keys.transpose(1,2))/self.d_model**.5 #n x n
        attention = self.softmax(scores)
        weighted = torch.matmul(attention, values) #n x d_model , torch.bmm is batch matmul for 3 dim
        return weighted


attention = SelfAttention(d_model=4)
x = torch.tensor([
        [[1., 2., 3., 4.],
         [2., 3., 4., 5.],
         [3., 4., 5., 6.]],

        [[4., 3., 2., 1.],
         [5., 4., 3., 2.],
         [6., 5., 4., 3.]]
])
weighted = attention.forward(x)
weighted

tensor([[[-2.2137,  1.8800, -0.6150, -4.0573],
         [-2.1646,  1.7778, -0.6503, -3.9122],
         [-2.1217,  1.6885, -0.6811, -3.7853]],

        [[ 0.1410,  3.1941,  1.7682, -3.3440],
         [ 0.1855,  3.1015,  1.7362, -3.2124],
         [ 0.2231,  3.0232,  1.7091, -3.1012]]], grad_fn=<BmmBackward0>)

In [3]:
#same computation overhead as self attention but more semantic representations

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0, "d_k=d_v=d_model/h"
        self.h = h #num_heads
        self.d_model = d_model #length of embeddings
        self.d_k = self.d_v = int(self.d_model / self.h)
        
        # self.query = nn.Linear(d_model,d_model)
        # self.key = nn.Linear(d_model,d_model)
        # self.value = nn.Linear(d_model,d_model)

        self.W_Q = nn.Linear(d_model,d_model)
        self.W_K = nn.Linear(d_model,d_model)
        self.W_V = nn.Linear(d_model,d_model)

        self.W_O = nn.Linear(d_model,d_model)
        
        
        
    # def computeQKV(self,x):
    #     batch_size, seq_length, d_model = x.size()
    #     Q = self.query(x)
    #     K = self.key(x)
    #     V = self.value(x)
    #     return Q,K,V
        
    def forward(self,Q,K,V):
        batch_size, seq_length, d_model = Q.size()
        
        queries = self.W_Q(Q)
        keys = self.W_K(K)
        values = self.W_V(V)

        queries = queries.reshape(batch_size, self.h, seq_length, self.d_k) #(batch size, seq_length, d_model) ==> (batch size, seq_length, h, d_k)
        keys = keys.reshape(batch_size, self.h, seq_length, self.d_k)
        values = values.reshape(batch_size, self.h, seq_length, self.d_v)

        queries = queries.transpose(1,2) #(batch size, h, seq_length, d_k), flips the dims 1 and 2
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)

        scores = torch.matmul(queries, keys.transpose(-2,-1))/self.d_k**.5
        attention = F.softmax(scores, dim=-1)
        weighted = torch.matmul(attention, values)
        concat = weighted.reshape(batch_size, seq_length, d_model)
        out = self.W_O(concat) #W_O
        return out

x = torch.tensor([
        [[1., 2., 3., 4.],
         [2., 3., 4., 5.],
         [3., 4., 5., 6.]],

        [[4., 3., 2., 1.],
         [5., 4., 3., 2.],
         [6., 5., 4., 3.]]
])
mha = MultiHeadAttention(d_model=4,h=2)
Q,K,V = x,x,x
# Q,K,V = mha.computeQKV(x), not needed because it is a double linear transformation so does the same thing X-->Q=WX-->W_Q*Q as X--> W_Q*X
mha.forward(Q,K,V)

tensor([[[-1.5894, -0.5033, -0.5218, -0.4477],
         [-2.6329, -0.1831, -0.1079,  0.8398],
         [-2.2823, -0.5435, -0.5072, -0.3413]],

        [[-0.7238, -0.0269, -0.1612,  0.4117],
         [-0.7842, -0.4719, -0.5143, -0.7158],
         [-1.0875, -0.0605, -0.1583,  0.3944]]], grad_fn=<ViewBackward0>)

In [20]:
Q

tensor([[[-1.5914,  1.0629, -2.4170, -1.6738],
         [-2.1618,  0.9891, -2.8905, -2.2325],
         [-2.7321,  0.9153, -3.3640, -2.7912]],

        [[-0.4519, -1.0591,  0.3878, -1.1029],
         [-1.0222, -1.1328, -0.0857, -1.6616],
         [-1.5925, -1.2066, -0.5592, -2.2203]]], grad_fn=<ViewBackward0>)