In [114]:
from torch import nn
import torch
import torch.nn.functional as F
from math import sqrt

from encoder import EncoderLayer
from decoder import DecoderLayer
from multi_head_attention import MultiHeadAttention
from positional_encoding import PositionalEncoding
from scaled_dot_attention import attention
from embedding import WordEmbedding

In [115]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` DecoderLayer modules applied one after another.

    Args:
        embed_dim: number of expected features in the input (same as d_model).
        input_dim: length of the sequence.
        num_heads: number of attention heads per layer.
        num_layers: number of decoder layers to stack (default 6).
        dropout: accepted for API symmetry with Encoder; it is currently NOT
            forwarded to DecoderLayer. TODO confirm whether DecoderLayer takes it.
    """

    def __init__(self, embed_dim, input_dim, num_heads, num_layers=6, dropout=0.1):
        super().__init__()

        # One independent DecoderLayer per stack position.
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(embed_dim, input_dim, num_heads) for _ in range(num_layers)]
        )

    def forward(self, x, encod_out, mask=None):
        """Run ``x`` through every decoder layer in order and return the result.

        Args:
            x: decoder input (Transformer passes the embedded target sequence).
            encod_out: encoder output. NOTE(review): this is currently NOT passed
                to the layers, so the decoder never attends to the encoder.
                Wiring it in requires DecoderLayer.forward to accept it --
                TODO confirm DecoderLayer's signature and pass it through.
            mask: optional mask handed unchanged to every layer.

        Returns:
            The output of the last decoder layer.
        """
        for layer in self.decoder_layers:
            x = layer(x, mask)

        return x

In [116]:
class Encoder(nn.Module):
    """Sequential stack of EncoderLayer modules.

    Args:
        embed_dim: number of expected features in the input (same as d_model).
        input_dim: length of the sequence.
        num_heads: number of attention heads per layer.
        num_layers: how many EncoderLayer modules to stack (default 6).
        dropout: accepted for API symmetry but not passed on to the layers.
    """

    def __init__(self, embed_dim, input_dim, num_heads, num_layers=6, dropout=0.1):
        super().__init__()

        stacked = []
        for _ in range(num_layers):
            stacked.append(EncoderLayer(embed_dim, input_dim, num_heads))
        self.encoder_layers = nn.ModuleList(stacked)

    def forward(self, x, mask=None):
        """Feed ``x`` through each encoder layer in turn, sharing ``mask``."""
        hidden = x
        for encoder_layer in self.encoder_layers:
            hidden = encoder_layer(hidden, mask)
        return hidden

In [117]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from the project's building blocks.

    Args:
        vocab_size: vocabulary size passed to both WordEmbedding modules.
        embed_dim: model width (d_model).
        input_dim: sequence length, passed to the positional encodings and to
            the encoder/decoder stacks.
        num_heads: number of attention heads.
        num_layers_encod: number of encoder layers (default 6).
        num_layers_decod: number of decoder layers (default 6).
        dropout: forwarded to Encoder/Decoder. In this file they do not pass it
            on to their layers.
    """

    def __init__(self, vocab_size, embed_dim, input_dim, num_heads, num_layers_encod=6, num_layers_decod=6, dropout=0.1):
        super().__init__()

        # Separate embeddings for the source (encoder) and target (decoder) streams.
        self.embedding1 = WordEmbedding(vocab_size, embed_dim)
        self.embedding2 = WordEmbedding(vocab_size, embed_dim)

        # forward() applies these. With them commented out, every forward call
        # failed with AttributeError.
        self.encoding1 = PositionalEncoding(embed_dim, input_dim)
        self.encoding2 = PositionalEncoding(embed_dim, input_dim)

        self.encoder = Encoder(embed_dim, input_dim, num_heads, num_layers_encod, dropout)
        self.decoder = Decoder(embed_dim, input_dim, num_heads, num_layers_decod, dropout)

        # NOTE(review): a standard Transformer projects d_model -> vocab_size
        # here; embed_dim is kept so the output shape seen by callers is unchanged.
        self.linear = nn.Linear(embed_dim, embed_dim)
        # nn.Linear acts on the last axis, so dim=-1 turns each position's
        # output vector into a probability distribution.
        self.soft = nn.Softmax(dim=-1)

    def forward(self, input, output, mask=None):
        """Encode ``input``, decode ``output`` and return softmax scores.

        Args:
            input: source sequence fed to ``embedding1`` (presumably token ids;
                TODO confirm what WordEmbedding expects).
            output: target sequence fed to ``embedding2``.
            mask: optional mask shared by the encoder and the decoder.
                NOTE(review): a decoder normally needs its own causal mask.

        Returns:
            Tensor whose last dimension sums to 1 (softmax over ``embed_dim``).
        """
        # Source side: embed, add position information, then encode.
        src = self.encoding1(self.embedding1(input))
        encod_out = self.encoder(src, mask)

        # Target side: embed and add position information.
        tgt = self.encoding2(self.embedding2(output))

        # Combine target and encoder output in the decoder.
        decod_out = self.decoder(tgt, encod_out, mask)

        # Project, then normalize over the feature axis.
        return self.soft(self.linear(decod_out))

In [118]:
# TEST: smoke-run the encoder stack and the full Transformer on a tiny input.
# Seed first so the printed (randomly initialised) outputs are reproducible.
torch.manual_seed(0)

embed_dim = 3
num_heads = 1
input_dim = 3  # sequence length

# NOTE(review): float values are passed straight into the embedding layer and
# vocab_size=1 below is smaller than the value 10. Confirm what WordEmbedding
# expects (integer token ids < vocab_size?).
x = torch.tensor([[0, 10, 0]], dtype=torch.float32)

encoder = Encoder(embed_dim, input_dim, num_heads, num_layers=6)
encod_out = encoder(x)
print(encod_out)

model = Transformer(1, embed_dim, input_dim, num_heads)
# Call the module itself rather than .forward() so nn.Module hooks run.
output = model(x, x)
print(output)

tensor([[[-1.1853,  1.2607, -0.0753]]], grad_fn=<NativeLayerNormBackward0>)
tensor([[[0.2447, 0.4535, 0.3018]]], grad_fn=<SoftmaxBackward0>)
