In [1]:
# ATTENTION

from math import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding is split into ``heads`` chunks of size ``head = embed_dim // heads``.
    A single (head x head) bias-free projection per role (query/key/value) is applied
    to every chunk, i.e. the projection weights are shared across heads.

    Args:
        embed_dim: size of the input/output embedding (512 by default).
        heads: number of attention heads (8 by default).

    Raises:
        ValueError: if ``heads`` is not positive or ``embed_dim`` is not a multiple
            of ``heads`` (the per-head reshape in ``forward`` would be impossible).
    """

    def __init__(self, embed_dim=512, heads=8):

        super(MultiHeadAttention, self).__init__()

        # Fail early with a clear message instead of silently truncating the head
        # size and crashing later inside ``reshape``.
        if heads <= 0 or embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be a positive multiple of heads ({heads})"
            )

        self.embed_dim = embed_dim  # 512 by default
        self.heads = heads  # 8 heads by default
        self.head = embed_dim // heads  # 512 // 8 = 64 by default

        # query, value, key projections: (64x64), applied to the last dimension
        self.query = nn.Linear(self.head, self.head, bias=False)  # the Query matrix
        self.value = nn.Linear(self.head, self.head, bias=False)  # the Value matrix
        self.key = nn.Linear(self.head, self.head, bias=False)  # the Key matrix

        # fully connected output layer: (heads * head) x embed_dim, e.g. 512x512
        self.fc_out = nn.Linear(self.head * self.heads, embed_dim)

    def forward(self, key, query, value, mask=None):
        """Compute attention of ``query`` over ``key``/``value``.

        Args:
            key: tensor of shape (batch_size, k_len, embed_dim).
            query: tensor of shape (batch_size, q_len, embed_dim).
            value: tensor of shape (batch_size, v_len, embed_dim); v_len must equal
                k_len for the weighted sum below to be defined.
            mask: optional tensor broadcastable to (batch_size, heads, q_len, k_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch_size, q_len, embed_dim).
        """
        batch_size = key.size(0)
        k_len, q_len, v_len = key.size(1), query.size(1), value.size(1)

        # split the embedding into heads:
        # (batch_size x seq_len x embed_dim) -> (batch_size x seq_len x heads x head)
        # example: (32x10x512) -> (32x10x8x64)
        key = key.reshape(batch_size, k_len, self.heads, self.head)
        query = query.reshape(batch_size, q_len, self.heads, self.head)
        value = value.reshape(batch_size, v_len, self.heads, self.head)

        key = self.key(key)  # (32x10x8x64)
        query = self.query(query)  # (32x10x8x64)
        value = self.value(value)  # (32x10x8x64)

        ############### query x key ###############

        # query: (batch_size, q_len, heads, head), key: (batch_size, k_len, heads, head)
        # product: (batch_size, heads, q_len, k_len), e.g: (32x8x10x10)
        product = torch.einsum("bqhd,bkhd->bhqk", query, key)

        # masked positions get a very large negative score so that softmax
        # assigns them (numerically) zero weight
        if mask is not None:
            product = product.masked_fill(mask == 0, float("-1e20"))

        product = product / sqrt(self.head)

        # normalise over the key axis
        scores = F.softmax(product, dim=-1)

        ############### scores x value ###############

        # scores: (batch_size, heads, q_len, k_len), e.g: (32x8x10x10)
        # value: (batch_size, v_len, heads, head), e.g: (32x10x8x64)
        # weighted sum: (batch_size, q_len, heads, head), then the heads are
        # concatenated back into (batch_size, q_len, embed_dim), e.g: (32x10x512)
        output = torch.einsum("nhql,nlhd->nqhd", scores, value).reshape(
            batch_size, q_len, self.heads * self.head
        )

        output = self.fc_out(output)  # (32x10x512) -> (32x10x512)

        return output

In [2]:
# EMBEDDINGS

from math import sin, cos, sqrt, log
import torch
import torch.nn as nn


class Embedding(nn.Module):
    """Token-id lookup table whose vectors are scaled by sqrt(embed_dim)."""

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # learnable table of shape (vocab_size, embed_dim)
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, x):
        """Return the embedding of the integer ids in ``x`` multiplied by sqrt(embed_dim)."""
        token_vectors = self.embed(x)
        return token_vectors * sqrt(self.embed_dim)


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to a batch of embeddings.

    For position ``pos`` and pair index ``i`` the table holds
    ``sin(pos / 10000 ** (2 * i / embed_dim))`` in column ``2 * i`` and the
    matching cosine in column ``2 * i + 1``.

    Args:
        embed_dim: size of the embedding dimension.
        max_seq_len: number of positions precomputed in the table.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, embed_dim, max_seq_len=5000, dropout=0.1):

        super(PositionalEncoding, self).__init__()
        self.embed_dim = embed_dim
        self.dropout = nn.Dropout(dropout)

        positional_encoding = torch.zeros(max_seq_len, self.embed_dim)
        position = torch.arange(0, max_seq_len).unsqueeze(1)  # (max_seq_len, 1)
        # 1 / 10000 ** (2i / embed_dim), computed in log space
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2) * -(log(10000.0) / embed_dim)
        )
        angles = position * div_term  # (max_seq_len, ceil(embed_dim / 2))
        positional_encoding[:, 0::2] = torch.sin(angles)
        # For an odd embed_dim there is one cosine column fewer than sine columns,
        # so trim the angles to the number of odd-indexed columns.
        positional_encoding[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])
        pe = positional_encoding.unsqueeze(0)  # (1, max_seq_len, embed_dim)

        # register_buffer stores "pe" in the state_dict and moves it with the
        # module (.to / .cuda) without making it a trainable parameter
        self.register_buffer('pe', pe)

    def pe_sin(self, position, i):
        """Scalar reference value of the table at column ``2 * i`` for ``position``."""
        return sin(position / (10000 ** (2 * i / self.embed_dim)))

    def pe_cos(self, position, i):
        """Scalar reference value of the table at column ``2 * i + 1`` for ``position``."""
        return cos(position / (10000 ** (2 * i / self.embed_dim)))

    def forward(self, x):
        """Add the encoding of the first ``x.size(1)`` positions to ``x`` and apply dropout.

        ``x`` has shape (batch_size, seq_len, embed_dim); the result has the same shape.
        """
        # "pe" is a buffer, not a parameter, so no gradient flows into the table
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)

In [3]:
# UTILS

import copy
import torch.nn as nn

def replicate(block, N=6) -> nn.ModuleList:
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``block``."""
    clones = []
    for _ in range(N):
        # deepcopy so that every copy owns its own parameters
        clones.append(copy.deepcopy(block))
    return nn.ModuleList(clones)

In [4]:
# ENCODER

import torch.nn as nn

def replicate(block, N=6) -> nn.ModuleList:
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``block``.

    Args:
        block: the module to copy.
        N: number of copies to create (6 by default).
    """
    # Imported locally: this cell only imports torch.nn, so without this the
    # function would depend on `copy` having been imported by an earlier cell.
    import copy

    block_stack = nn.ModuleList([copy.deepcopy(block) for _ in range(N)])
    return block_stack

class TransformerBlock(nn.Module):
    """Attention sublayer followed by a position-wise feed-forward sublayer.

    Each sublayer is wrapped as ``dropout(norm(sublayer(input) + input))``.
    Note that the same ``LayerNorm`` instance (``self.norm``) and the same
    ``Dropout`` instance are reused after both sublayers.

    Args:
        embed_dim: size of the embedding dimension.
        heads: number of attention heads.
        expansion_factor: hidden width of the feed-forward network as a multiple
            of ``embed_dim``.
        dropout: dropout probability.
    """

    def __init__(self,
                 embed_dim=512,
                 heads=8,
                 expansion_factor=4,
                 dropout=0.2
                 ):

        super(TransformerBlock, self).__init__()

        self.attention = MultiHeadAttention(embed_dim, heads)  # the multi-head attention
        self.norm = nn.LayerNorm(embed_dim)  # the normalization layer

        # the FeedForward layer
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, expansion_factor * embed_dim),  # e.g: 512 -> 2048
            nn.ReLU(),  # ReLU activation function
            nn.Linear(embed_dim * expansion_factor, embed_dim),  # e.g: 2048 -> 512
        )

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, key, query, value, mask=None):
        """Attend ``query`` over ``key``/``value`` and apply the feed-forward network.

        Args:
            key, query, value: tensors of shape (batch_size, seq_len, embed_dim);
                ``key`` and ``value`` share a sequence length, ``query`` may differ.
            mask: optional attention mask forwarded to the attention layer.

        Returns:
            Tensor with the shape of ``query``.
        """
        #################### Multi-Head Attention ####################
        attention_out = self.attention(key, query, value, mask)  # e.g.: 32x10x512

        # Residual connection around the attention sublayer. The attention output
        # has the query's sequence length, so the residual must be the query stream
        # (adding `value` only works when query and value are the same tensor).
        attention_out = attention_out + query  # e.g.: 32x10x512

        # normalize, then dropout
        attention_norm = self.dropout(self.norm(attention_out))  # e.g.: 32x10x512

        #################### Feed-Forward Network ####################
        fc_out = self.feed_forward(attention_norm)  # e.g.: 32x10x512 -> 32x10x2048 -> 32x10x512

        # Residual connection around the feed-forward sublayer
        fc_out = fc_out + attention_norm  # e.g.: 32x10x512

        # normalize, then dropout
        fc_norm = self.dropout(self.norm(fc_out))  # e.g.: 32x10x512

        return fc_norm


class Encoder(nn.Module):
    """Embeds token ids, adds positional encodings and runs a stack of TransformerBlocks.

    Args:
        seq_len: number of positions precomputed by the positional encoder.
        vocab_size: size of the source vocabulary.
        embed_dim: size of the embedding dimension.
        num_blocks: how many TransformerBlocks are stacked.
        expansion_factor: feed-forward expansion inside each block.
        heads: number of attention heads per block.
        dropout: dropout probability used inside each block.
    """

    def __init__(self,
                 seq_len,
                 vocab_size,
                 embed_dim=512,
                 num_blocks=6,
                 expansion_factor=4,
                 heads=8,
                 dropout=0.2
                 ):
        super().__init__()

        # token embedding table: (vocab_size x embed_dim)
        self.embedding = Embedding(vocab_size, embed_dim)

        # sinusoidal table covering `seq_len` positions
        self.positional_encoder = PositionalEncoding(embed_dim, seq_len)

        # one prototype block, deep-copied `num_blocks` times
        prototype = TransformerBlock(embed_dim, heads, expansion_factor, dropout)
        self.blocks = replicate(prototype, num_blocks)

    def forward(self, x):
        """Encode token ids ``x``; returns (batch_size, seq_len, embed_dim), e.g. 32x10x512."""
        hidden = self.embedding(x)
        hidden = self.positional_encoder(hidden)

        # self-attention: every block receives the same tensor as key, query and value
        for layer in self.blocks:
            hidden = layer(hidden, hidden, hidden)

        return hidden

In [5]:
# DECODER

import torch.nn as nn
import torch.nn.functional as F



class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention followed by a cross-attention TransformerBlock.

    Args:
        embed_dim: size of the embedding dimension.
        heads: number of attention heads.
        expansion_factor: feed-forward expansion inside the TransformerBlock.
        dropout: dropout probability.
    """

    def __init__(self,
                 embed_dim=512,
                 heads=8,
                 expansion_factor=4,
                 dropout=0.2
                 ):

        super(DecoderBlock, self).__init__()

        # masked multi-head self-attention over the decoder states
        self.attention = MultiHeadAttention(embed_dim, heads)
        # normalization
        self.norm = nn.LayerNorm(embed_dim)
        # Dropout to avoid overfitting
        self.dropout = nn.Dropout(dropout)
        # cross-attention + feed-forward
        self.transformerBlock = TransformerBlock(embed_dim, heads, expansion_factor, dropout)

    def forward(self, key, query, x, mask):
        """Run one decoder layer.

        The argument roles follow the call made by ``Decoder.forward``:
        ``block(encoder_output, x, encoder_output, mask)``.

        Args:
            key: encoder output used as keys of the cross-attention,
                shape (batch_size, src_len, embed_dim).
            query: current decoder states, shape (batch_size, trg_len, embed_dim).
            x: encoder output used as values of the cross-attention,
                shape (batch_size, src_len, embed_dim).
            mask: target mask for the self-attention, broadcastable to
                (batch_size, heads, trg_len, trg_len).

        Returns:
            Tensor of shape (batch_size, trg_len, embed_dim).
        """
        # Masked self-attention must run over the decoder states; the target
        # mask is (trg_len x trg_len) and only matches that stream.
        decoder_attention = self.attention(query, query, query, mask)
        # residual connection + normalization
        query = self.dropout(self.norm(decoder_attention + query))
        # cross-attention (no mask): decoder states query the encoder memory,
        # followed by the feed-forward sublayer
        decoder_attention_output = self.transformerBlock(key, query, x)

        return decoder_attention_output


class Decoder(nn.Module):
    """Embeds target token ids, adds positional encodings and runs a stack of DecoderBlocks.

    Args:
        target_vocab_size: size of the target vocabulary.
        seq_len: number of positions precomputed by the positional encoder.
        embed_dim: size of the embedding dimension.
        num_blocks: how many DecoderBlocks are stacked.
        expansion_factor: feed-forward expansion inside each block.
        heads: number of attention heads per block.
        dropout: dropout probability.
    """

    def __init__(self,
                 target_vocab_size,
                 seq_len,
                 embed_dim=512,
                 num_blocks=6,
                 expansion_factor=4,
                 heads=8,
                 dropout=0.2
                 ):

        super(Decoder, self).__init__()

        # Use the same scaled embedding wrapper as the Encoder so that source and
        # target embeddings are both multiplied by sqrt(embed_dim) before the
        # positional encoding is added.
        self.embedding = Embedding(target_vocab_size, embed_dim)
        # the positional encoding
        self.positional_encoder = PositionalEncoding(embed_dim, seq_len)

        # define the set of decoder blocks
        self.blocks = replicate(DecoderBlock(embed_dim, heads, expansion_factor, dropout), num_blocks)
        # dropout against overfitting
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, mask):
        """Decode target ids ``x`` against ``encoder_output``.

        Args:
            x: target token ids, shape (batch_size, trg_len).
            encoder_output: encoder states, shape (batch_size, src_len, embed_dim).
            mask: target mask passed to every block.

        Returns:
            Tensor of shape (batch_size, trg_len, embed_dim).
        """
        x = self.dropout(self.positional_encoder(self.embedding(x)))  # 32x10x512

        # each block receives the encoder output as key and value source and the
        # running decoder states as query
        for block in self.blocks:
            x = block(encoder_output, x, encoder_output, mask)

        return x

In [6]:
# TRANSFORMER

import torch
import torch.nn as nn
import torch.nn.functional as F

class Transformer(nn.Module):
    """Encoder-decoder Transformer producing a distribution over the target vocabulary.

    Args:
        embed_dim: size of the embedding dimension.
        src_vocab_size: size of the source vocabulary.
        target_vocab_size: size of the target vocabulary.
        seq_len: number of positions precomputed by the positional encoders.
        num_blocks: number of encoder blocks and of decoder blocks.
        expansion_factor: feed-forward expansion inside each block.
        heads: number of attention heads.
        dropout: dropout probability.
    """

    def __init__(self,
                 embed_dim,
                 src_vocab_size,
                 target_vocab_size,
                 seq_len,
                 num_blocks=6,
                 expansion_factor=4,
                 heads=8,
                 dropout=0.2):
        super(Transformer, self).__init__()
        self.target_vocab_size = target_vocab_size

        self.encoder = Encoder(seq_len=seq_len,
                               vocab_size=src_vocab_size,
                               embed_dim=embed_dim,
                               num_blocks=num_blocks,
                               expansion_factor=expansion_factor,
                               heads=heads,
                               dropout=dropout)

        self.decoder = Decoder(target_vocab_size=target_vocab_size,
                               seq_len=seq_len,
                               embed_dim=embed_dim,
                               num_blocks=num_blocks,
                               expansion_factor=expansion_factor,
                               heads=heads,
                               dropout=dropout)

        # projects decoder states onto the target vocabulary
        self.fc_out = nn.Linear(embed_dim, target_vocab_size)

    def make_trg_mask(self, trg):
        """Build the causal mask for a batch of target ids.

        Args:
            trg: tensor of shape (batch_size, trg_len).

        Returns:
            Tensor of shape (batch_size, 1, trg_len, trg_len) on ``trg``'s device,
            holding ones on and below the diagonal and zeros above it.
        """
        batch_size, trg_len = trg.shape
        # Create the mask on the same device as the target ids; otherwise
        # masked_fill in the attention fails when the model runs on an accelerator.
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        trg_mask = causal.expand(batch_size, 1, trg_len, trg_len)
        return trg_mask

    def forward(self, source, target):
        """Return softmax probabilities of shape (batch_size, trg_len, target_vocab_size).

        Args:
            source: source token ids, shape (batch_size, src_len).
            target: target token ids, shape (batch_size, trg_len).
        """
        trg_mask = self.make_trg_mask(target)
        enc_out = self.encoder(source)
        outputs = self.decoder(target, enc_out, trg_mask)
        # NOTE: the output is already normalised with softmax (probabilities, not logits)
        output = F.softmax(self.fc_out(outputs), dim=-1)
        return output

In [7]:
# TEST


import torch

# Smoke test: build a small Transformer and push one toy batch through it.

# Vocabulary sizes; the token ids used below lie in 0..10.
src_vocab_size = 11
target_vocab_size = 11
# Number of stacked encoder blocks and decoder blocks.
num_blocks = 6
# Passed to Transformer as seq_len, i.e. the number of precomputed positions.
seq_len = 512

# let 0 be the sos token and 1 be the eos token
# two source and two target sequences of 12 token ids each
src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1],
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1],
                       [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])

print(src.shape, target.shape)
model = Transformer(embed_dim=512,
                    src_vocab_size=src_vocab_size,
                    target_vocab_size=target_vocab_size,
                    seq_len=seq_len,
                    num_blocks=num_blocks,
                    expansion_factor=4,
                    heads=8)

# print the module tree, then run a forward pass and report the output shape
print(model)
out = model(src, target)
print(f"Output Shape: {out.shape}")

torch.Size([2, 12]) torch.Size([2, 12])
Transformer(
  (encoder): Encoder(
    (embedding): Embedding(
      (embed): Embedding(11, 512)
    )
    (positional_encoder): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (blocks): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadAttention(
          (query): Linear(in_features=64, out_features=64, bias=False)
          (value): Linear(in_features=64, out_features=64, bias=False)
          (key): Linear(in_features=64, out_features=64, bias=False)
          (fc_out): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (1): TransformerBlock(
      

In [8]:
# Check the shape of the source batch defined in the test cell: (batch_size, seq_len).
src.shape

torch.Size([2, 12])

In [17]:
# Sample one sequence of 5 random token ids in [3, 10) as a quick extra input.
a = torch.randint(3, 10, (1, 5))

In [18]:
# Display the sampled token ids.
a

tensor([[5, 4, 4, 3, 4]])

In [19]:
# Shape check: one sequence of five tokens.
a.shape

torch.Size([1, 5])

In [23]:
# Feed the same tensor as both source and target to `model` and display the output shape.
model(a, a).shape