In [2]:
import torch
import torch.nn as nn
import pandas

In [3]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are projected to queries, keys and values by three bias-free linear
    layers. Each projection is split into ``num_heads`` heads of size
    ``head_dim = embed_dim // num_heads`` and attended per head. The heads are
    then concatenated and mixed by ``fc_out``.

    Args:
        embed_dim: Size of the last dimension of every input and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super(SelfAttention, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.keys = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.queries = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.values = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.fc_out = nn.Linear(self.num_heads * self.head_dim, self.embed_dim)

    def forward(self, values, keys, queries, mask=None):
        """Attend from ``queries`` over ``keys``/``values``.

        Args:
            values: Tensor of shape (batch, value_len, embed_dim).
            keys: Tensor of shape (batch, key_len, embed_dim); key_len must
                equal value_len.
            queries: Tensor of shape (batch, query_len, embed_dim).
            mask: Optional tensor broadcastable to
                (batch, num_heads, query_len, key_len). Positions where it is
                0/False are excluded from attention.

        Returns:
            Tensor of shape (batch, query_len, embed_dim).
        """
        # Sizes must come from the tensor shapes, not from indexing the tensors.
        batch_size = queries.shape[0]
        key_len, query_len, value_len = keys.shape[1], queries.shape[1], values.shape[1]

        # Project, then split the last dimension into (num_heads, head_dim).
        keys = self.keys(keys).reshape(batch_size, key_len, self.num_heads, self.head_dim)
        queries = self.queries(queries).reshape(batch_size, query_len, self.num_heads, self.head_dim)
        values = self.values(values).reshape(batch_size, value_len, self.num_heads, self.head_dim)

        # energy: (batch, heads, query_len, key_len), scaled by sqrt(head_dim),
        # the per-head dot-product width.
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value so masking also works
            # under half precision (a -1e20 literal overflows float16).
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values per head. Then concatenate the heads back to
        # (batch, query_len, num_heads * head_dim) for the output projection.
        att_out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            batch_size, query_len, self.num_heads * self.head_dim
        )
        return self.fc_out(att_out)

In [4]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer layer: an attention sub-layer followed by a feed-forward sub-layer.

    For each sub-layer, the output is added to its input (residual connection),
    layer-normalised, and then passed through dropout.

    Args:
        embed_dim: Model width.
        num_heads: Number of heads given to ``SelfAttention``.
        dropout: Dropout probability.
        forward_expansion: Hidden width of the feed-forward net, as a multiple
            of ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        hidden_dim = forward_expansion * embed_dim
        # Submodules are created in a fixed order; this keeps parameter
        # initialisation (RNG consumption) and state_dict ordering stable.
        self.attention = SelfAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, values, keys, queries, mask=None):
        """Apply attention (queries over keys/values) and then the feed-forward net.

        The residual connection around attention uses ``queries``, so the
        output has the same shape as ``queries``.
        """
        attended = self.attention(values, keys, queries, mask)
        # NOTE(review): dropout is applied after normalisation. The original
        # Transformer paper applies it to the sub-layer output before the
        # residual add; confirm this placement is intended.
        hidden = self.dropout(self.norm1(queries + attended))
        return self.dropout(self.norm2(hidden + self.ff(hidden)))

In [5]:
class Encoder(nn.Module):
    """Token embedding plus learned positional embedding, followed by stacked TransformerBlocks.

    Args:
        src_vocab_size: Number of rows in the source word embedding.
        embed_dim: Model width.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        num_heads: Attention heads per layer.
        device: Stored as ``self.device`` for compatibility. Positions are
            created on the input's device.
        forward_expansion: Feed-forward width multiplier.
        dropout: Dropout probability.
        max_length: Maximum supported sequence length (rows of the position embedding).
    """

    def __init__(self, src_vocab_size, embed_dim, num_layers, num_heads, device, forward_expansion, dropout, max_length):
        super(Encoder, self).__init__()
        self.embed_dim = embed_dim
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_length, embed_dim)

        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, dropout, forward_expansion)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids, shape (batch, seq_len).
            mask: Attention mask forwarded unchanged to every layer, or None.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).

        Raises:
            ValueError: If ``seq_len`` exceeds the position-embedding size.
        """
        batch_size, seq_len = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_len > max_length:
            # Fail clearly instead of an out-of-range embedding lookup
            # (which on CUDA surfaces as an opaque device-side assert).
            raise ValueError(f"sequence length {seq_len} exceeds max_length {max_length}")
        positions = torch.arange(seq_len, device=x.device).expand(batch_size, seq_len)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Feed each layer the previous layer's output so the blocks actually stack.
        # With num_layers == 0 this returns the embedded input.
        for layer in self.layers:
            out = layer(out, out, out, mask)
        return out

In [6]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then a TransformerBlock over the encoder output.

    Args:
        embed_dim: Model width.
        num_heads: Attention heads for both attention sub-layers.
        forward_expansion: Feed-forward width multiplier for the TransformerBlock.
        dropout: Dropout probability.
        device: Stored as ``self.device``; not otherwise used here.
    """

    def __init__(self, embed_dim, num_heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.embed_dim = embed_dim
        self.device = device
        self.attention = SelfAttention(embed_dim, num_heads)
        self.norm = nn.LayerNorm(embed_dim)
        self.transformer_block = TransformerBlock(embed_dim, num_heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, tgt_mask):
        """Run one decoder layer.

        Args:
            x: Decoder hidden states, (batch, tgt_len, embed_dim).
            value: Tensor used as attention values in the second sub-layer.
            key: Tensor used as attention keys in the second sub-layer.
            src_mask: Mask for the second (cross) attention, or None.
            tgt_mask: Mask for the self-attention over ``x``, or None.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Masked self-attention over the target sequence, with residual + norm.
        attention = self.attention(x, x, x, tgt_mask)
        query = self.dropout(self.norm(attention + x))
        # Cross-attention: queries come from the decoder, keys/values from the caller.
        return self.transformer_block(value, key, query, src_mask)

In [7]:
class Decoder(nn.Module):
    """Target embedding, stacked DecoderBlocks, and a final projection to vocabulary scores.

    Args:
        tgt_vocab_size: Target vocabulary size (rows of the word embedding and
            width of the output projection).
        embed_dim: Model width.
        num_layers: Number of stacked ``DecoderBlock`` layers.
        num_heads: Attention heads per attention sub-layer.
        forward_expansion: Feed-forward width multiplier.
        dropout: Dropout probability.
        device: Stored as ``self.device`` and passed to each DecoderBlock.
        max_length: Maximum supported sequence length (rows of the position embedding).
    """

    def __init__(self, tgt_vocab_size, embed_dim, num_layers, num_heads, forward_expansion, dropout, device, max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_length, embed_dim)
        self.layers = nn.ModuleList([
            DecoderBlock(embed_dim, num_heads, forward_expansion, dropout, device)
            for _ in range(num_layers)
        ])
        # Projects model width to vocabulary size (no softmax applied).
        self.fc_out = nn.Linear(embed_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Decode target token ids against the encoder output.

        Args:
            x: Integer tensor of target token ids, shape (batch, seq_len).
            enc_out: Encoder output, used as both keys and values in every layer.
            src_mask: Mask for attention over ``enc_out``, or None.
            tgt_mask: Mask for self-attention over the target, or None.

        Returns:
            Unnormalised scores of shape (batch, seq_len, tgt_vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the position-embedding size.
        """
        batch_size, seq_len = x.shape
        max_length = self.position_embedding.num_embeddings
        if seq_len > max_length:
            raise ValueError(f"sequence length {seq_len} exceeds max_length {max_length}")
        positions = torch.arange(seq_len, device=x.device).expand(batch_size, seq_len)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, tgt_mask)

        # Project once, after the final layer, and return the result.
        return self.fc_out(x)
        

In [None]:
 class Transformer(nn.Module):
     def __init__(self, src_vocab_size, tgt_vocab_size, src_pad_idx, tgt_pad_idx, embed_size = 256, num_layers = 6, forward_expansion = 4, num_heads = 8, dropout = 0, device = 'cuda', max_length = 100):
         super(Transformer, self).__init__()
         self.encoder = Encoder(src_vocab_size, embed_dim, num_layers, num_heads, device, forward_expansion, dropout, max_length)
         self.decoder = Decoder(tgt_vocab_size, embed_dim, num_layers, num_heads, forward_expansion, dropout, device, max_length)
         self.src_pad_idx = src_pad_idx
         self.tgt_pad_idx = tgt_pad_idx
         self.device = device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_tgt_mask(self, tgt):
        batch_size, tgt_len = tgt.shape
        tgt_mask = torch.tril(torch.ones((tgt_len, tgt_len))).expand(batch_size, 1, tgt_len, tgt_len)
        return tgt_mask.to(self.device)

    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(tgt, enc_src, src_mask, tgt_mask)
        return out