In [73]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

In [74]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        return self.embedding(x)

In [75]:
torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]]).shape 

torch.Size([2, 12])

In [76]:
# for length of 12
Embedding(100, 512).forward(torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])).shape

torch.Size([2, 12, 512])

In [77]:
Embedding(50, 64).forward(torch.randint(2, 12, (2,12))).shape

torch.Size([2, 12, 64])

In [78]:
class PositionalEncoding(nn.Module):
    def __init__(self, seq_length, embedding_dim):
        super(PositionalEncoding, self).__init__()
        self.embedding_dim = embedding_dim
        self.pv = torch.zeros(seq_length, self.embedding_dim)
        self.seq_length = seq_length

    def forward(self, embedded_x):
        """
            Function to return positional encoding with combination of word embeddings
            and positional vector
        Arguments:
            pe: Positional vector variable of that shape
            self.embedding_dim: Dimension of the vector required for the model
            pos_encoding: Variable to hold final positional Encoding of the word
        """
        seq_length = embedded_x.shape[1]
        pos_encoding = torch.zeros_like(embedded_x)
        for batch_idx, batch in enumerate(pos_encoding):
            # print(f'batch shape: {batch.shape}')
            for pos, emb in enumerate(batch):
                emb = emb.reshape(1, -1)
                for i in range(0,self.embedding_dim, 2):
                    self.pv[0][i] = np.sin(pos/10000**((2*i)/self.embedding_dim))
                    # print(pos, i, pos_encoding.shape)
                    # print(pos_encoding[pos][i])

                    # Value of the word embedding is increased by the product with sqrt of 512
                    # Simple addition of the positional information to the word embedding
                    pos_encoding[batch_idx][pos][i] = (emb[0][i] * np.sqrt(self.embedding_dim)) + self.pv[0][i]

                    self.pv[0][i+1] = np.cos(pos/10000**((2*i)/self.embedding_dim))
                    pos_encoding[batch_idx][pos][i+1] = (emb[0][i+1] * np.sqrt(self.embedding_dim)) + self.pv[0][i+1]
        return pos_encoding


In [79]:
PositionalEncoding(10, 512).forward(torch.rand(10,12,512)).shape

torch.Size([10, 12, 512])

In [80]:
# # PositionalEncoding(10, 512).forward(torch.rand((10,512))).shape
# rand_ = torch.rand(10, 10, 512)
# # embeddings_temporary = torch.tensor([PositionalEncoding(10, 512).forward(r) for r in rand_])
# embeddings_temporary = torch.tensor([])
# for r in rand_:
#     for m in r:
#         print(m.reshape(1,-1).shape)
#         break

In [133]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim=512, num_heads=8):
        super(MultiHeadAttention, self).__init__()

        self.embedding_dim = embedding_dim
        self.num_heads = num_heads

        self.single_head_dim = int(self.embedding_dim/self.num_heads)

        self.query_mat = nn.Linear(self.single_head_dim, self.single_head_dim, bias=False)
        self.key_mat = nn.Linear(self.single_head_dim, self.single_head_dim, bias=False)
        self.value_mat = nn.Linear(self.single_head_dim, self.single_head_dim, bias=False)

        self.out = nn.Linear(self.num_heads * self.single_head_dim, self.embedding_dim)

    def forward(self, key, query, value, mask=None):
        # [BS, seq_len, embedding_dim]
        batch_size = key.size(0)
        seq_length = key.size(1)

        # seq length can vary in decoder since it comes from lower decoder, not encoder
        seq_length_query = query.size(1)

        # reshaping it as [BS, seq_length, num_heads, single_head_dim]
        # embedding_dim = num_heads * single_head_dim
        # each word has 512 dim, 64 dim in each head
        # whole sequence has 10 words, so!
        key = key.view(batch_size, seq_length, self.num_heads, self.single_head_dim)
        query = query.view(batch_size, seq_length, self.num_heads, self.single_head_dim)
        value = value.view(batch_size, seq_length, self.num_heads, self.single_head_dim)

        k = self.key_mat(key) #[32x8x64x10]
        q = self.query_mat(query) #[32x8x64x10]
        v = self.value_mat(value) #[32x8x64x10]

        q = q.transpose(1, 2) #[32x8x10x64]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # for transpose of k
        k_transposed = k.transpose(-1,-2) # 32x8x10x64

        # scaled dot product attention
        # attention score!
        product = torch.matmul(q, k_transposed) # 32x8x10x10

        # if there is mask
        # replace the values in product, corresponding to the 0s in the same space in the mask tensor 
        # the product will have replaced the values in those locations with '-inf'
        # print(product[0,:,0,:20])
        print(f'Before mask, product shape: {product.shape}')
        if mask is not None:
            product = product.masked_fill(mask==0, float('-inf'))
        print(f'After mask, product shape: {product.shape}')        
        # print(product[0,:,0,:20])
        
        product = product / math.sqrt(self.single_head_dim)
        print(f'Shape of product is {product.shape}')

        # softmax to normalize the scores, so they sum to 1
        # softmax scores
        softmax_scores = F.softmax(product, dim=1) # 32x8x10x10

        # final attention scores
        # multiply each value vector with softmax scores
        scores = torch.matmul(softmax_scores, v) # 32x8x10x64

        concat = scores.transpose(1,2).contiguous().view(batch_size, seq_length_query, self.single_head_dim * self.num_heads)
        #32x10x512

        output = self.out(concat) # 32x10x512

        return output

In [134]:
64*8

512

In [135]:
mha = MultiHeadAttention()

In [136]:
rand = torch.randn(32,10,512)
mha.forward(rand, rand, rand).shape

Before mask, product shape: torch.Size([32, 8, 10, 10])
After mask, product shape: torch.Size([32, 8, 10, 10])
Shape of product is torch.Size([32, 8, 10, 10])


torch.Size([32, 10, 512])

In [137]:
class EncoderBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads = 8):
        super(EncoderBlock, self).__init__()

        self.attention = MultiHeadAttention(embedding_dim, num_heads)

        self.norm = nn.LayerNorm(embedding_dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, 4*embedding_dim),
            nn.ReLU(),
            nn.Linear(4*embedding_dim, embedding_dim)
        )

        self.dropout = nn.Dropout(0.2)

    def forward(self, key, query, value):
        attention_scores = self.attention(key, query, value)
        attention_with_residual = attention_scores + value # 32x10x512
        norm1_out = self.dropout(self.norm(attention_with_residual))

        ff_out = self.feed_forward(norm1_out)

        ff_out_with_residual = ff_out + norm1_out

        norm2_out = self.dropout(self.norm(ff_out_with_residual))

        return norm2_out

In [138]:
encoder_block = EncoderBlock(embedding_dim=512, num_heads=8)
x = torch.randn(2,19,512)
encoder_block.forward(x,x,x).shape

Before mask, product shape: torch.Size([2, 8, 19, 19])
After mask, product shape: torch.Size([2, 8, 19, 19])
Shape of product is torch.Size([2, 8, 19, 19])


torch.Size([2, 19, 512])

In [139]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, seq_length, embedding_dim, num_layers=2, num_heads=8):
        super(TransformerEncoder, self).__init__()

        self.embedding_layer = Embedding(vocab_size, embedding_dim)
        self.positional_encoder = PositionalEncoding(seq_length, embedding_dim)
        self.layers = nn.ModuleList(
            [EncoderBlock(embedding_dim, num_heads) for i in range(num_layers)]
        )

    def forward(self, x):
        embedding_output = self.embedding_layer(x)
        print(f'Embedding output shape is: {embedding_output.shape}')
        positional_encoding = self.positional_encoder(embedding_output)

        # for first encoder that takes in the positional encoding vector
        out = positional_encoding
        print(f'Out shape: {out.shape}')
        for layer in self.layers:
            out = layer(out, out, out)

        return out

In [140]:
te = TransformerEncoder(vocab_size=100, seq_length=10, embedding_dim=512, num_layers=2, num_heads=8)
te

TransformerEncoder(
  (embedding_layer): Embedding(
    (embedding): Embedding(100, 512)
  )
  (positional_encoder): PositionalEncoding()
  (layers): ModuleList(
    (0-1): 2 x EncoderBlock(
      (attention): MultiHeadAttention(
        (query_mat): Linear(in_features=64, out_features=64, bias=False)
        (key_mat): Linear(in_features=64, out_features=64, bias=False)
        (value_mat): Linear(in_features=64, out_features=64, bias=False)
        (out): Linear(in_features=512, out_features=512, bias=True)
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2048, out_features=512, bias=True)
      )
      (dropout): Dropout(p=0.2, inplace=False)
    )
  )
)

In [141]:
source = torch.tensor([[3,21,4,6,1,3,5,34,3,1,3,5,1,3,67,8,3,67,4],
                       [3,43,3,5,1,3,67,8,3,67,4,34,23,4,1,3,4,65,2]])
target = torch.randn(1,10)
source.shape , target.shape

(torch.Size([2, 19]), torch.Size([1, 10]))

In [142]:
te.forward(source).shape

Embedding output shape is: torch.Size([2, 19, 512])
Out shape: torch.Size([2, 19, 512])
Before mask, product shape: torch.Size([2, 8, 19, 19])
After mask, product shape: torch.Size([2, 8, 19, 19])
Shape of product is torch.Size([2, 8, 19, 19])
Before mask, product shape: torch.Size([2, 8, 19, 19])
After mask, product shape: torch.Size([2, 8, 19, 19])
Shape of product is torch.Size([2, 8, 19, 19])


torch.Size([2, 19, 512])

### Decoder of Transformer 

In [143]:
mha_decoder = MultiHeadAttention()
x = torch.randn(32,10,512)
# x will be the q, k and v for now 

In [144]:
# creating a mask, upper triangular for autoregressive behaviour
target_mask = torch.triu(torch.ones((10,10)),diagonal=1)
print(f'Shape of target_mask: {target_mask.shape}')
print(target_mask)

Shape of target_mask: torch.Size([10, 10])
tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


In [145]:
target_mask.unsqueeze(0).shape, target_mask.unsqueeze(0).unsqueeze(0).shape

(torch.Size([1, 10, 10]), torch.Size([1, 1, 10, 10]))

In [146]:
target_mask = target_mask.unsqueeze(0).unsqueeze(0)
print(target_mask)
target_mask = target_mask==0
print(target_mask.shape)
target_mask 

tensor([[[[0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
          [0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
          [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
          [0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
          [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
          [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
          [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])
torch.Size([1, 1, 10, 10])


tensor([[[[ True, False, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]]])

In [147]:
mha_decoder.forward(x, x, x, target_mask).shape 

Before mask, product shape: torch.Size([32, 8, 10, 10])
After mask, product shape: torch.Size([32, 8, 10, 10])
Shape of product is torch.Size([32, 8, 10, 10])


torch.Size([32, 10, 512])

In [148]:
class DecoderBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads=8):
        super(DecoderBlock, self).__init__()

        self.self_attention = MultiHeadAttention(embedding_dim, num_heads)
        self.cross_attention = MultiHeadAttention(embedding_dim, num_heads)

        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, 4*embedding_dim),
            nn.ReLU(),
            nn.Linear(4*embedding_dim, embedding_dim)
        )

        self.norm = nn.LayerNorm(embedding_dim)

        self.dropout = nn.Dropout(0.2)

    def forward(self, key, query, value, enc_output=None, mask=None):
        masked_attention_scores = self.self_attention(key, query, value, mask)
        attention_with_residual = masked_attention_scores + value 
        norm_out = self.dropout(self.norm(attention_with_residual))
        print(f'Norm out shape: {norm_out.shape}')
        # key: key_encoder, query:query_decoder, value:value_encoder
        cross_attention_scores = self.cross_attention(enc_output, norm_out, enc_output, mask)
        cross_attention_with_residual = cross_attention_scores + norm_out 
        norm2_out = self.dropout(self.norm(cross_attention_with_residual))
        
        ff_out = self.ffn(norm2_out)
        ff_out_with_residual = ff_out + norm_out

        norm3_out = self.dropout(self.norm(ff_out_with_residual))

        return norm3_out

In [153]:
decoder = DecoderBlock(embedding_dim=512, num_heads=8)
x = torch.randn(2,16,512)
enc_output = torch.randn(2,16,512)

# target mask should be (target_seq_len, target_seq_len)
target_mask = torch.triu(torch.ones((16,16)), diagonal=1)
target_mask = target_mask.unsqueeze(0).unsqueeze(0)
target_mask = target_mask==0
decoder.forward(x,x,x,enc_output, target_mask).shape 

Before mask, product shape: torch.Size([2, 8, 16, 16])
After mask, product shape: torch.Size([2, 8, 16, 16])
Shape of product is torch.Size([2, 8, 16, 16])
Norm out shape: torch.Size([2, 16, 512])
Before mask, product shape: torch.Size([2, 8, 16, 16])
After mask, product shape: torch.Size([2, 8, 16, 16])
Shape of product is torch.Size([2, 8, 16, 16])


torch.Size([2, 16, 512])

In [162]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, seq_length, embedding_dim, num_layers=2, num_heads=8):
        super(TransformerDecoder, self).__init__()

        self.embedding_layer = Embedding(vocab_size, embedding_dim)
        self.positional_encoder = PositionalEncoding(seq_length, embedding_dim)
        self.layers = nn.ModuleList(
            [DecoderBlock(embedding_dim, num_heads) for i in range(num_layers)]
        )

    def forward(self, x, encoder_output=None, mask=None):
        embedding_output = self.embedding_layer(x)
        positional_encoding = self.positional_encoder(embedding_output)

        out = positional_encoding
        print(f'Positional encoding shape: {out.shape}')
        for layer in self.layers:
            out = layer(out, out, out, encoder_output, mask)
        return out 

In [163]:
td = TransformerDecoder(vocab_size=100, seq_length=10, embedding_dim=512, num_layers=2, num_heads=8)
td

TransformerDecoder(
  (embedding_layer): Embedding(
    (embedding): Embedding(100, 512)
  )
  (positional_encoder): PositionalEncoding()
  (layers): ModuleList(
    (0-1): 2 x DecoderBlock(
      (self_attention): MultiHeadAttention(
        (query_mat): Linear(in_features=64, out_features=64, bias=False)
        (key_mat): Linear(in_features=64, out_features=64, bias=False)
        (value_mat): Linear(in_features=64, out_features=64, bias=False)
        (out): Linear(in_features=512, out_features=512, bias=True)
      )
      (cross_attention): MultiHeadAttention(
        (query_mat): Linear(in_features=64, out_features=64, bias=False)
        (key_mat): Linear(in_features=64, out_features=64, bias=False)
        (value_mat): Linear(in_features=64, out_features=64, bias=False)
        (out): Linear(in_features=512, out_features=512, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): ReLU()
        (2): Linear(in

In [164]:
source = torch.tensor([[3,21,4,6,1,3,5,34,3,1,3,5,1,3,67,8,3,67,4],
                       [3,43,3,5,1,3,67,8,3,67,4,34,23,4,1,3,4,65,2]])

In [166]:
x = torch.randint(0,15,(2,16))
enc_output = torch.randn(2,16,512)

# target mask should be (target_seq_len, target_seq_len)
target_mask = torch.triu(torch.ones((16,16)), diagonal=1)
target_mask = target_mask.unsqueeze(0).unsqueeze(0)
target_mask = target_mask==0

td.forward(x, enc_output, target_mask).shape

Positional encoding shape: torch.Size([2, 16, 512])
Before mask, product shape: torch.Size([2, 8, 16, 16])
After mask, product shape: torch.Size([2, 8, 16, 16])
Shape of product is torch.Size([2, 8, 16, 16])
Norm out shape: torch.Size([2, 16, 512])
Before mask, product shape: torch.Size([2, 8, 16, 16])
After mask, product shape: torch.Size([2, 8, 16, 16])
Shape of product is torch.Size([2, 8, 16, 16])
Before mask, product shape: torch.Size([2, 8, 16, 16])
After mask, product shape: torch.Size([2, 8, 16, 16])
Shape of product is torch.Size([2, 8, 16, 16])
Norm out shape: torch.Size([2, 16, 512])
Before mask, product shape: torch.Size([2, 8, 16, 16])
After mask, product shape: torch.Size([2, 8, 16, 16])
Shape of product is torch.Size([2, 8, 16, 16])


torch.Size([2, 16, 512])

In [None]:
class DecoderBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads=8):
        super(DecoderBlock, self).__init__()

        self.masked_attention = MultiHeadAttention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attention = MultiHeadAttention(embedding_dim, num_heads)
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim)
        )

        self.norm3 = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, encoder_output, mask=None):
        # Masked Self-Attention
        masked_attention_scores = self.masked_attention(x, x, x, mask)
        masked_attention_with_residual = masked_attention_scores + x
        norm1_out = self.dropout(self.norm1(masked_attention_with_residual))

        # Cross-Attention (Query from decoder, Key/Value from encoder output)
        cross_attention_scores = self.cross_attention(encoder_output, norm1_out, encoder_output)
        cross_attention_with_residual = cross_attention_scores + norm1_out
        norm2_out = self.dropout(self.norm2(cross_attention_with_residual))

        # Feed Forward Network
        ff_out = self.feed_forward(norm2_out)
        ff_out_with_residual = ff_out + norm2_out
        norm3_out = self.dropout(self.norm3(ff_out_with_residual))

        return norm3_out

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, seq_length, embedding_dim, num_layers=2, num_heads=8):
        super(TransformerDecoder, self).__init__()

        self.embedding_layer = Embedding(vocab_size, embedding_dim)
        self.positional_encoder = PositionalEncoding(seq_length, embedding_dim)
        self.layers = nn.ModuleList(
            [DecoderBlock(embedding_dim, num_heads) for _ in range(num_layers)]
        )

    def forward(self, x, encoder_output, mask=None):
        embedding_output = self.embedding_layer(x)
        print(f'Embedding output shape is: {embedding_output.shape}')
        positional_encoding = self.positional_encoder(embedding_output)

        out = positional_encoding
        for layer in self.layers:
            out = layer(out, encoder_output, mask)

        return out
