In [185]:
import torch
import torch.nn as nn
import math

In [10]:
# Built-in PyTorch encoder layer: d_model=512, 8 heads (positional args: d_model, nhead).
# NOTE(review): not used by any visible cell, and its width (512) differs from the 768-d tensors below.
encoder_layer = nn.TransformerEncoderLayer(512, 8)

In [11]:
# Stack 6 copies of the encoder layer defined in the previous cell.
transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=6)

In [132]:
# Uniform random tensor in [0, 1) of shape (32, 40, 768); not consumed by any visible cell.
src = torch.rand((32, 40, 768))

In [225]:
class MultiHeadSelfAttn(nn.Module):
    '''
    Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three
    dmodel -> dmodel linear layers, split along the feature axis into
    `heads` chunks of width dk = dmodel // heads, attended independently per
    head, re-concatenated, and finally mapped to `output_dim` by a linear layer.
    '''

    def __init__(self, dmodel, output_dim=None, heads=8):
        '''
        dmodel: should be equal to the embedding_dimension
        output_dim: what the size of the output embedding should be
                    (None -> same as dmodel)
        heads: number of attention heads; must divide dmodel evenly
        '''
        super(MultiHeadSelfAttn, self).__init__()
        if output_dim is None:
            output_dim = dmodel
        self.dmodel = dmodel
        self.heads = heads
        self.output_dim = output_dim
        assert dmodel % heads == 0, 'dmodel must be divisible by heads'

        # dk is the per-head width. It must be derived from `heads` (not a
        # hard-coded 8), otherwise the head split in forward() fails.
        self.dk = dmodel // heads

        # Q/K/V projections for all heads at once (dmodel -> heads * dk == dmodel)
        self.key_projections = nn.Linear(self.dmodel, self.dmodel)
        self.value_projections = nn.Linear(self.dmodel, self.dmodel)
        self.query_projections = nn.Linear(self.dmodel, self.dmodel)

        # The final linear layer, applied to the concatenated heads
        self.end_linear = nn.Linear(self.dmodel, self.output_dim)

        # Normalise scores over the key axis, i.e. the last dim of
        # (batch, heads, query_len, key_len). dim=1 would normalise over heads.
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, tensor):
        '''Reshape (batch, seq, dmodel) -> (batch, heads, seq, dk) by splitting the feature axis.'''
        batches, sequence_length, _ = tensor.size()
        # View as (batch, seq, heads, dk) first so each head gets a contiguous
        # slice of features, then move heads in front of the sequence axis.
        return tensor.view(batches, sequence_length, self.heads, self.dk).transpose(1, 2)

    def scaled_attention(self, query, keys, values):
        '''
        Scaled dot-product attention: softmax(Q K^T / sqrt(dk)) V.

        query, keys, values: tensors of shape (..., seq_len, dk) with matching
        leading dimensions. Returns a tensor of shape (..., query_len, dk).
        '''
        scores = torch.matmul(query, keys.transpose(-2, -1)) / math.sqrt(self.dk)
        weights = self.softmax(scores)
        # Each query row of `weights` mixes the value rows; no transpose here.
        return torch.matmul(weights, values)

    def forward(self, embeddings):
        '''
        embeddings: should be of dimensions (batch, sequence_length, embedding_dimension)
        Returns a tensor of shape (batch, sequence_length, output_dim).
        '''
        batches, sequence_length, _ = embeddings.size()

        query = self._split_heads(self.query_projections(embeddings))
        keys = self._split_heads(self.key_projections(embeddings))
        values = self._split_heads(self.value_projections(embeddings))

        attn_out = self.scaled_attention(query, keys, values)
        # (batch, heads, seq, dk) -> (batch, seq, heads, dk) -> (batch, seq, dmodel).
        # reshape (not view) because the transpose makes the tensor non-contiguous.
        attn_out = attn_out.transpose(1, 2).reshape(batches, sequence_length, self.dmodel)

        return self.end_linear(attn_out)
        

In [226]:
# 768-d input, 768-d output, default number of heads.
multi_attn = MultiHeadSelfAttn(dmodel=768, output_dim=768)

In [227]:
# Standard-normal dummy batch: (batch=32, sequence_length=40, embedding_dimension=768).
x = torch.randn((32, 40, 768))

In [229]:
# Shape check: expect (batch, sequence_length, output_dim).
multi_attn(x).shape

torch.Size([32, 40, 768])