In [2]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from tqdm.notebook import tqdm
import numpy as np
%config Completer.use_jedi = False

## Positional Encoding

In [3]:
def get_positional_encodings(embeds, input):
    pe = torch.zeros((embeds.shape[0], embeds.shape[1], embeds.shape[2]), dtype = torch.float)

    for pos in range(input.shape[-1]):
        v = pe[0, pos]
        for i in range(v.shape[0]):
            if i % 2 ==0:
                pwr = (2 * i) / dim_in
                v[i] = np.sin(pos / (1e4 ** pwr))
            else:
                pwr = (2 * i) / dim_in
                v[i] = np.cos(pos / (1e4 ** pwr))

        pe[0, pos] = v

    return pe

## Self-Attention

In [4]:
# torch.manual_seed(1)

class AttentionHead(nn.Module):
    def __init__(self, dim_in, attn_dim):
        super().__init__()
        
        self.wq = nn.Linear(dim_in, attn_dim)
        self.wk = nn.Linear(dim_in, attn_dim)
        self.wv = nn.Linear(dim_in, attn_dim)
        
    def forward(self, input, q, k, v, mask_inputs):

        if input == None:
            query = q
            key = k
            value = v
        else:
            query = self.wq(input)
            key = self.wk(input)
            value = self.wv(input)
        
        score = torch.bmm(query, torch.transpose(key, 1, 2))
        scale = query.shape[-1] ** 0.5
        score = score / scale
        
        if mask_inputs:
            masked_attention = torch.zeros(score.shape[0], score.shape[1], score.shape[2])
            
            for i in range(masked_attention.shape[1]):
                masked_attention[0, i, (i+1):] = -float('inf')

            score += masked_attention
            
        softmax = F.softmax(score, dim = -1)
        return torch.bmm(softmax, value)

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, dim_in, attn_dim):
        super().__init__()
        
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, attn_dim) for _ in range(num_heads)]
        )
        
        self.linear = nn.Linear(num_heads*attn_dim, dim_in)
        
    def forward(self, input, q, k, v, mask_inputs):
        
        if input == None:
            lin = nn.Linear(q.shape[-1] * len(self.heads), q.shape[-1])
            return lin(
                torch.cat([h(input, q, k, v, mask_inputs) for h in self.heads], dim = -1)
            )
        
        else:
            return self.linear(
                torch.cat([h(input, q, k, v, mask_inputs) for h in self.heads], dim = -1)
            )

## Risidual & Feed Forward

In [5]:
class Risidual(nn.Module):
    def __init__(self, dropout, dim_in):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim_in)
        
    def forward(self, prev_layer_input, prev_layer_output):
        z = prev_layer_input + prev_layer_output
        return self.norm(z[-1] + self.dropout(prev_layer_output))
    

class FeedForward(nn.Module):
    def __init__(self, dim_in, dim_feedforward):
        super().__init__()
        
        self.lin1 = nn.Linear(dim_in, dim_feedforward)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(dim_feedforward, dim_in)
    
    def forward(self, risidual_output):
        encoder_output = self.lin1(risidual_output)
        encoder_output = self.relu(encoder_output)
        encoder_output = self.lin2(encoder_output)
        
        return encoder_output

## Encoder

In [6]:
# torch.manual_seed(1)

class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                dim_in,
                attn_dim,
                num_heads,
                dim_feedforward,
                dropout):
        
        super().__init__()
        
        self.mha = MultiHeadAttention(num_heads,dim_in,attn_dim)
        
        self.risidual_mha = Risidual(dropout,dim_in)
        
        self.feed_forward = FeedForward(dim_in, dim_feedforward)
        
        self.risidual_ff = Risidual(dropout,dim_in)
        
    def forward(self, input):
        mha = self.mha(input, None, None, None, False)
        risidual_mha = self.risidual_mha(input, mha)
        feed_forward = self.feed_forward(risidual_mha)
        risidual_ff = self.risidual_ff(risidual_mha, feed_forward)
        
        return risidual_ff

class TransformerEncoder(nn.Module):
    def __init__(self,
                 num_layers,
                 dim_in,
                 seq_len,
                 attn_dim,
                 num_heads,
                 vocab_size,
                 dim_feedforward,
                 dropout):
        
        super().__init__()
    
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(dim_in,
                                    attn_dim,
                                    num_heads,
                                    dim_feedforward,
                                    dropout)
            for _ in range(num_layers)
        ])
        
        self.embedding_layer = nn.Embedding(vocab_size, dim_in)
    
    def forward(self, input):
        embeds = self.embedding_layer(input)
        embeds += get_positional_encodings(embeds, input)
        
        for encoder_layer in self.encoder_layers:
            embeds = encoder_layer(embeds)
            
        return embeds

## Decoder

In [7]:
# torch.manual_seed(1)

class TransformerDecoderLayer(nn.Module):
    def __init__(self, dim_in, attn_dim, num_heads, dim_feedforward, dropout):
        super().__init__()
        
        self.mha = MultiHeadAttention(num_heads,dim_in,attn_dim)
        self.risidual_mha = Risidual(dropout, dim_in)
        self.risidual_mha2 = Risidual(dropout, dim_in)
        self.feed_forward = FeedForward(dim_in,dim_feedforward)
        self.risidual_ff = Risidual(dropout, dim_in)
        
    def forward(self, input, encoder_output):
        mha = self.mha(input, None, None, None, True)
        risidual_mha = self.risidual_mha(input, mha) # this is the query vector for the next MHAL
        mha2 = self.mha(None, risidual_mha, encoder_output, encoder_output, False)
        risidual_mha2 = self.risidual_mha2(risidual_mha, mha2)
        feed_forward = self.feed_forward(risidual_mha2)
        risidual_ff = self.risidual_ff(risidual_mha2, feed_forward)
        
        return risidual_ff
    
class TransformerDecoder(nn.Module):
    def __init__(self,
                 num_layers, 
                 dim_in, 
                 seq_len, 
                 attn_dim, 
                 num_heads, 
                 vocab_size, 
                 dim_feedforward,
                 dropout):
        super().__init__()
        
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(dim_in,
                                    attn_dim,
                                    num_heads, 
                                    dim_feedforward, 
                                    dropout)
            
            for _ in range(num_layers)
        ])
        
        self.embedding_layer = nn.Embedding(vocab_size, dim_in)
        
    def forward(self, input, encoder_output):
        embeds = self.embedding_layer(input)
        embeds += get_positional_encodings(embeds, input)
        
        for decoder_layer in self.decoder_layers:
            embeds = decoder_layer(embeds, encoder_output)
            
        return embeds

In [8]:
# torch.manual_seed(1)

class Transformer(nn.Module):
    def __init__(self,
                dim_in,
                seq_len,
                attn_dim,
                num_heads,
                vocab_size,
                dim_feedforward,
                dropout,
                num_layers,
                target_seq_len,
                target_vocab_size):
        super().__init__()
        
        self.encoder = TransformerEncoder(num_layers,
                                         dim_in,
                                         seq_len,
                                         attn_dim,
                                         num_heads,
                                         vocab_size,
                                         dim_feedforward,
                                         dropout)
        
        self.decoder = TransformerDecoder(num_layers,
                                         dim_in,
                                         target_seq_len,
                                         attn_dim,
                                         num_heads,
                                         target_vocab_size,
                                         dim_feedforward,
                                         dropout)
        self.target_vocab_size = target_vocab_size
        
    def forward(self, src, trgt):
        encoder_output = self.encoder(src)
        decoder_output = self.decoder(trgt, encoder_output)
        
        decoder_output = decoder_output.view(1, -1)
        lin = nn.Linear(decoder_output.shape[-1], self.target_vocab_size)
        output = F.softmax(lin(decoder_output), dim = -1)
        
        return output

In [14]:
dim_in = 5
seq_len = 4
batch_size = 3
attn_dim = 64
num_heads = 8
vocab_size = 10
dim_feedforward = 2048
dropout = 0.1
num_layers = 6
target_seq_len = 5
target_vocab_size = 12

src = torch.randint(0, vocab_size, (batch_size, seq_len))
target = torch.randint(0, target_vocab_size, (batch_size, target_seq_len))

transformer = Transformer(dim_in,
                          seq_len,
                          attn_dim,
                          num_heads,
                          vocab_size,
                          dim_feedforward,
                          dropout,
                          num_layers,
                          target_seq_len,
                          target_vocab_size)

transformer_output = transformer(src, target)
print(transformer_output)
print(transformer_output.shape)

tensor([[0.1386, 0.0569, 0.0419, 0.0828, 0.0496, 0.0664, 0.2338, 0.0395, 0.0853,
         0.0541, 0.0708, 0.0801]], grad_fn=<SoftmaxBackward>)
torch.Size([1, 12])


In [12]:
print(src)
print(src.shape)

tensor([[4, 4, 7, 2],
        [1, 5, 0, 5],
        [7, 2, 2, 9]])
torch.Size([3, 4])
