In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import sys
sys.path.append("..")

from transformer.modules import LayerNorm, clone_layer, Sublayer

> The decoder is also composed of a stack of N = 6 identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization. We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position _i_ can depend only on the known outputs at positions less than _i_.
>
> — Vaswani et al., [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762), §3.1

In [2]:
class Decoder(nn.Module):
    """Stack of ``N`` decoder layers followed by one final layer normalisation.

    Args:
        layer: Prototype decoder layer. It is replicated ``N`` times via
            ``clone_layer`` (presumably as independent copies -- confirm in
            ``transformer.modules``). It must expose a ``size`` attribute,
            which sets the width of the final ``LayerNorm``.
        N: Number of layers in the stack.
    """

    def __init__(self, layer: nn.Module, N: int):
        super().__init__()
        # Registration order (layers, then norm) is kept so that
        # state_dict keys and parameter order stay unchanged.
        self.layers = clone_layer(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x: torch.Tensor,
                memory: torch.Tensor,
                src_mask: torch.Tensor,
                tgt_mask: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through every layer in order, then normalise the result.

        Args:
            x: Decoder input for the first layer; each later layer receives
                the previous layer's output.
            memory: Passed unchanged to every layer (per the paper, the
                output of the encoder stack).
            src_mask: Passed unchanged to every layer.
            tgt_mask: Passed unchanged to every layer.

        Returns:
            ``self.norm`` applied to the last layer's output (or to ``x``
            itself when the stack is empty).
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

In [3]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder attention, feed-forward.

    Each of the three sub-layers is wrapped by its own ``Sublayer`` instance
    (presumably residual connection + layer norm + dropout, as described in
    the paper quote above -- confirm against ``transformer.modules``).

    Args:
        size: Model width. It is stored as ``self.size`` (``Decoder`` reads it
            to size its final ``LayerNorm``) and forwarded to ``Sublayer``.
        masked_self_attn: Called as ``(query, key, value, mask)`` with the
            target sequence for all three inputs and ``tgt_mask`` as the mask.
        enc_attn: Called as ``(query, key, value, mask)`` with the decoder
            state as the query, ``memory`` as key and value, and ``src_mask``.
        feed_forward: Called with the output of the encoder-attention step.
        dropout_prob: Forwarded to each ``Sublayer``.
    """

    def __init__(self, size: int,
                 masked_self_attn: nn.Module,
                 enc_attn: nn.Module,
                 feed_forward: nn.Module,
                 dropout_prob: float):
        super().__init__()
        # Assignment order is kept so that submodule registration (and
        # therefore state_dict / parameters() ordering) does not change.
        self.size = size
        self.masked_self_attn = masked_self_attn
        self.enc_attn = enc_attn
        self.feed_forward = feed_forward
        self.sublayer = clone_layer(
            Sublayer(size, dropout_prob=dropout_prob), 3)

    def forward(self, x: torch.Tensor,
                memory: torch.Tensor,
                src_mask: torch.Tensor,
                tgt_mask: torch.Tensor) -> torch.Tensor:
        """Apply the three wrapped sub-layers in sequence.

        Returns:
            Output of the third (feed-forward) ``Sublayer``.
        """
        def self_attention(h: torch.Tensor) -> torch.Tensor:
            # Query, key and value all come from the target-side state.
            return self.masked_self_attn(h, h, h, tgt_mask)

        def cross_attention(h: torch.Tensor) -> torch.Tensor:
            # Query from the decoder; key and value from ``memory``.
            return self.enc_attn(h, memory, memory, src_mask)

        hidden = self.sublayer[0](x, self_attention)
        hidden = self.sublayer[1](hidden, cross_attention)
        return self.sublayer[2](hidden, self.feed_forward)

<h2 align="center">Encoder-Decoder Architecture</h2>
<div align="center">
    <img src="images/decoder.png" alt="Encoder decoder network" />
</div>

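<div align="left">
    Still to implement:
    <ul>
        <li>Positional encoding / position-wise feed-forward network</li>
        <li>Multi-head attention (self-attention, masked self-attention, encoder-decoder attention)</li>
        <li>
            Masking functions
            <ol>
                <li>Encoder (source) masks – prevent attention to padding tokens. Alternatively, pack the inputs so that no padding mask is needed.</li>
                <li>Decoder (target) masks – prevent each position from attending to subsequent positions during training. Combined with the one-position offset of the output embeddings, this ensures the prediction at position <i>i</i> depends only on outputs at positions less than <i>i</i>.</li>
            </ol>
        </li>
        <li>Embeddings</li>
    </ul>
</div>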