In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

In [2]:
class Embedding(nn.Module):
    """Lookup table that turns integer token ids into dense vectors."""

    def __init__(self, vocab_size, embedding_dim):
        """
        Arguments:
            vocab_size: Number of rows in the lookup table
            embedding_dim: Length of the vector stored for each id
        """
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

    def forward(self, x):
        """Return the table rows selected by the integer ids in ``x``."""
        token_vectors = self.embedding(x)
        return token_vectors

In [3]:
# Two id sequences of 12 ids each; the cell displays the resulting shape
Embedding(100, 512)(
    torch.tensor([
        [0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
        [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1],
    ])
).shape

torch.Size([2, 12, 512])

In [37]:
class PositionalEncoding(nn.Module):
    """Scale word embeddings by sqrt(embedding_dim) and add sinusoidal position signals."""

    def __init__(self, seq_length, embedding_dim):
        """
        Arguments:
            seq_length: Number of positions whose encodings are precomputed;
                longer inputs are still accepted and get a table built on demand
            embedding_dim: Size of the last axis of the embeddings
        """
        super(PositionalEncoding, self).__init__()
        self.embedding_dim = embedding_dim
        self.seq_length = seq_length
        # Non-persistent buffer: it moves with the module on .to()/.cuda() but is
        # kept out of state_dict(), so the set of saved keys does not change.
        self.register_buffer("pv", self._build_table(seq_length), persistent=False)

    def _build_table(self, length, device=None):
        """Return the (length, embedding_dim) sinusoidal table for positions 0..length-1."""
        positions = torch.arange(length, dtype=torch.float32, device=device).unsqueeze(1)
        even_dims = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32, device=device)
        # even_dims already holds 2k (0, 2, 4, ...), so the exponent is
        # even_dims / d; multiplying by 2 again would square the wavelength.
        angles = positions / (10000.0 ** (even_dims / self.embedding_dim))

        table = torch.zeros(length, self.embedding_dim, dtype=torch.float32, device=device)
        table[:, 0::2] = torch.sin(angles)
        # For an odd embedding_dim there is one fewer odd column than even ones.
        table[:, 1::2] = torch.cos(angles)[:, : self.embedding_dim // 2]
        return table

    def forward(self, embedded_x):
        """
            Function to return the word embeddings combined with the positional vector
        Arguments:
            embedded_x: Word embeddings indexed as [batch, position, embedding_dim]
        Returns:
            Tensor shaped like embedded_x holding
            embedded_x * sqrt(embedding_dim) + positional vector
        """
        seq_length = embedded_x.shape[1]
        if seq_length <= self.pv.shape[0]:
            table = self.pv[:seq_length]
        else:
            # Input is longer than the precomputed table: build what is needed.
            table = self._build_table(seq_length, device=embedded_x.device)
        table = table.to(device=embedded_x.device, dtype=embedded_x.dtype)

        # The embedding values are scaled up before the positional information
        # is added; the table broadcasts over the batch axis.
        return embedded_x * math.sqrt(self.embedding_dim) + table


In [38]:
# Module built for 10 positions, fed a random (10, 12, 512) batch; displays the output shape
PositionalEncoding(seq_length=10, embedding_dim=512)(torch.rand(10, 12, 512)).shape

torch.Size([10, 12, 512])

In [40]:
# # PositionalEncoding(10, 512).forward(torch.rand((10,512))).shape
# rand_ = torch.rand(10, 10, 512)
# # embeddings_temporary = torch.tensor([PositionalEncoding(10, 512).forward(r) for r in rand_])
# embeddings_temporary = torch.tensor([])
# for r in rand_:
#     for m in r:
#         print(m.reshape(1,-1).shape)
#         break

In [41]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``num_heads`` heads.

    One ``single_head_dim x single_head_dim`` projection per role (query, key,
    value) is applied to every head, followed by a single output projection
    over the concatenated heads.
    """

    def __init__(self, embedding_dim=512, num_heads=8):
        """
        Arguments:
            embedding_dim: Size of the last axis of the inputs and of the output
            num_heads: Number of attention heads; must divide embedding_dim
        Raises:
            ValueError: If embedding_dim is not a multiple of num_heads
        """
        super(MultiHeadAttention, self).__init__()

        if embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.embedding_dim = embedding_dim
        self.num_heads = num_heads

        self.single_head_dim = self.embedding_dim // self.num_heads

        self.query_mat = nn.Linear(self.single_head_dim, self.single_head_dim, bias=False)
        self.key_mat = nn.Linear(self.single_head_dim, self.single_head_dim, bias=False)
        self.value_mat = nn.Linear(self.single_head_dim, self.single_head_dim, bias=False)

        self.out = nn.Linear(self.num_heads * self.single_head_dim, self.embedding_dim)

    def forward(self, key, query, value, mask=None):
        """
        Arguments:
            key: Tensor [BS, key_len, embedding_dim]
            query: Tensor [BS, query_len, embedding_dim]; query_len may differ from key_len
            value: Tensor [BS, key_len, embedding_dim]
            mask: Optional tensor broadcastable to [BS, num_heads, query_len, key_len];
                entries equal to 0 mark key positions a query must not attend to
        Returns:
            Tensor [BS, query_len, embedding_dim]
        """
        batch_size = key.size(0)
        seq_length = key.size(1)

        # the query length can differ from the key/value length, so it is
        # tracked separately and used for every query-side reshape
        seq_length_query = query.size(1)

        # split the last axis into heads: [BS, len, num_heads, single_head_dim]
        key = key.reshape(batch_size, seq_length, self.num_heads, self.single_head_dim)
        query = query.reshape(batch_size, seq_length_query, self.num_heads, self.single_head_dim)
        value = value.reshape(batch_size, seq_length, self.num_heads, self.single_head_dim)

        # project, then move heads in front of positions: [BS, num_heads, len, single_head_dim]
        k = self.key_mat(key).transpose(1, 2)
        q = self.query_mat(query).transpose(1, 2)
        v = self.value_mat(value).transpose(1, 2)

        # [BS, num_heads, query_len, key_len]
        product = torch.matmul(q, k.transpose(-1, -2))
        product = product / math.sqrt(self.single_head_dim)

        if mask is not None:
            # The most negative finite value (rather than -inf) keeps a fully
            # masked row from turning into NaN after the softmax.
            product = product.masked_fill(mask == 0, torch.finfo(product.dtype).min)

        # debug output kept so the recorded notebook outputs stay reproducible
        print(f'Shape of product is {product.shape}')

        # normalise over the key axis so each query's weights sum to 1
        softmax_scores = F.softmax(product, dim=-1)

        # weighted sum of values: [BS, num_heads, query_len, single_head_dim]
        scores = torch.matmul(softmax_scores, v)

        # merge the heads back: [BS, query_len, embedding_dim]
        concat = scores.transpose(1, 2).contiguous().view(
            batch_size, seq_length_query, self.single_head_dim * self.num_heads
        )

        output = self.out(concat)

        return output

In [42]:
# Same configuration as the class defaults, spelled out
mha = MultiHeadAttention(embedding_dim=512, num_heads=8)

In [43]:
rand = torch.randn(32, 10, 512)
# One tensor serves as key, query and value; the cell displays the output shape
mha(key=rand, query=rand, value=rand).shape

Shape of product is torch.Size([32, 8, 10, 10])


torch.Size([32, 10, 512])

In [44]:
class EncoderBlock(nn.Module):
    """Multi-head attention followed by a position-wise feed-forward network,
    each with a residual connection, layer normalisation and dropout."""

    def __init__(self, embedding_dim, num_heads = 8, dropout=0.2, expansion_factor=4):
        """
        Arguments:
            embedding_dim: Size of the last axis of the inputs and of the output
            num_heads: Number of heads passed to MultiHeadAttention
            dropout: Dropout probability applied after each normalisation
            expansion_factor: Hidden width of the feed-forward network as a
                multiple of embedding_dim
        """
        super(EncoderBlock, self).__init__()

        self.attention = MultiHeadAttention(embedding_dim, num_heads)

        # NOTE(review): one LayerNorm instance (one set of weights) is applied
        # after both sub-layers; kept as-is so parameter names do not change.
        self.norm = nn.LayerNorm(embedding_dim)

        hidden_dim = expansion_factor * embedding_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, key, query, value):
        """
        Arguments:
            key, query, value: Inputs forwarded to the attention layer; the
                attention output is added to ``value``, so the two must have
                matching shapes
        Returns:
            Output of the second norm + dropout step
        """
        attention_scores = self.attention(key, query, value)
        # residual connection around the attention sub-layer
        attention_with_residual = attention_scores + value
        norm1_out = self.dropout(self.norm(attention_with_residual))

        ff_out = self.feed_forward(norm1_out)

        # residual connection around the feed-forward sub-layer
        ff_out_with_residual = ff_out + norm1_out

        norm2_out = self.dropout(self.norm(ff_out_with_residual))

        return norm2_out

In [45]:
class TransformerEncoder(nn.Module):
    """Embedding layer, positional encoder and a stack of EncoderBlock layers."""

    def __init__(self, vocab_size, seq_length, embedding_dim, num_layers=2, num_heads=8):
        """
        Arguments:
            vocab_size: Passed to Embedding
            seq_length: Passed to PositionalEncoding
            embedding_dim: Shared by the embedding, the positional encoder and every block
            num_layers: Number of EncoderBlock instances in the stack
            num_heads: Passed to every EncoderBlock
        """
        super().__init__()

        self.embedding_layer = Embedding(vocab_size, embedding_dim)
        self.positional_encoder = PositionalEncoding(seq_length, embedding_dim)

        blocks = []
        for _ in range(num_layers):
            blocks.append(EncoderBlock(embedding_dim, num_heads))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through the embedding, the positional encoder and each block in order."""
        embedding_output = self.embedding_layer(x)
        print(f'Embedding output shape is: {embedding_output.shape}')

        # the first block consumes the positional encoder's output; every later
        # block consumes its predecessor's output as key, query and value alike
        hidden = self.positional_encoder(embedding_output)
        for block in self.layers:
            hidden = block(hidden, hidden, hidden)

        return hidden

In [46]:
te = TransformerEncoder(
    vocab_size=100,
    seq_length=10,
    embedding_dim=512,
    num_layers=2,
    num_heads=8,
)
# Bare name as the last expression so the cell displays the module tree
te

TransformerEncoder(
  (embedding_layer): Embedding(
    (embedding): Embedding(100, 512)
  )
  (positional_encoder): PositionalEncoding()
  (layers): ModuleList(
    (0-1): 2 x EncoderBlock(
      (attention): MultiHeadAttention(
        (query_mat): Linear(in_features=64, out_features=64, bias=False)
        (key_mat): Linear(in_features=64, out_features=64, bias=False)
        (value_mat): Linear(in_features=64, out_features=64, bias=False)
        (out): Linear(in_features=512, out_features=512, bias=True)
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2048, out_features=512, bias=True)
      )
      (dropout): Dropout(p=0.2, inplace=False)
    )
  )
)

In [47]:
# Two id sequences of 19 ids each
source = torch.tensor([
    [3, 21, 4, 6, 1, 3, 5, 34, 3, 1, 3, 5, 1, 3, 67, 8, 3, 67, 4],
    [3, 43, 3, 5, 1, 3, 67, 8, 3, 67, 4, 34, 23, 4, 1, 3, 4, 65, 2],
])
target = torch.randn(1, 10)
(source.shape, target.shape)

(torch.Size([2, 19]), torch.Size([1, 10]))

In [49]:
# Full encoder pass over the sample batch; the cell displays the output shape
te(source).shape

Embedding output shape is: torch.Size([2, 19, 512])
Shape of product is torch.Size([2, 8, 19, 19])
Shape of product is torch.Size([2, 8, 19, 19])


torch.Size([2, 19, 512])