In [None]:
from typing import Optional

import torch
from torch import nn, Tensor

In [172]:
# Util functions

def param(*dims) -> nn.Parameter:
    """Return a trainable ``nn.Parameter`` of shape ``dims``, Xavier-uniform initialised.

    ``nn.init.xavier_uniform_`` requires at least two dimensions, so ``dims``
    should describe a matrix (or a higher-rank tensor).
    """
    weight = nn.Parameter(torch.empty(dims))
    # xavier_uniform_ fills in place under no_grad, so initialising the
    # Parameter directly is equivalent to initialising the raw tensor first.
    nn.init.xavier_uniform_(weight)
    return weight

## Encoder

In [None]:
class EncoderLayer(nn.Module):
    """One post-LN Transformer encoder layer.

    Applies multi-head self-attention followed by a position-wise
    feed-forward network; each sub-layer is wrapped as
    ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int) -> None:
        """
        Args:
            d_model: width of the token representations; must be divisible
                by ``num_heads``.
            num_heads: number of attention heads.
            d_ff: hidden width of the feed-forward network.

        Raises:
            ValueError: if ``d_model`` is not divisible by ``num_heads``.
        """
        if d_model % num_heads != 0:
            raise ValueError("Number d_model must be divisible by num_heads")
        super().__init__()
        self.h = num_heads
        self.d = d_model // num_heads  # per-head width (d_k = d_v)
        # Attention scale sqrt(d_k) kept as a plain float, so it never has to
        # follow the parameters between devices.
        self.root_d = self.d ** 0.5

        # Q/K/V projections for all heads, each packed into a single
        # (d_model, d_model) matrix; the result is split into heads afterwards.
        self.wq = param(d_model, d_model)
        self.wk = param(d_model, d_model)
        self.wv = param(d_model, d_model)
        # Output projection that mixes the concatenated heads back to d_model.
        self.wo = param(d_model, d_model)
        # FFN parameters. The biases must be nn.Parameters (not plain tensors)
        # so they are registered on the module: trained by the optimizer,
        # saved in state_dict() and moved by .to(device).
        self.w1 = param(d_model, d_ff)
        self.b1 = nn.Parameter(torch.zeros(d_ff))
        self.w2 = param(d_ff, d_model)
        self.b2 = nn.Parameter(torch.zeros(d_model))
        # Layer norms applied after each residual connection.
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` of shape (batch_size, seq_len, d_model); the shape is preserved.

        Implemented as ``forward`` (not ``__call__``) so that
        ``nn.Module.__call__`` still runs registered hooks around it.
        """
        x = self.ln1(x + self.multi_head_attention(x, x, x))
        ffn = torch.relu(x @ self.w1 + self.b1) @ self.w2 + self.b2
        return self.ln2(x + ffn)

    def multi_head_attention(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Project, split into heads, attend, then merge heads and project out.

        Input shapes are (batch_size, seq_len, d_model). ``k`` and ``v`` must
        share a sequence length, which may differ from that of ``q``.
        Returns (batch_size, q_seq_len, d_model).
        """
        batch_size, q_seq_len, d_model = q.shape

        def split_heads(t: Tensor) -> Tensor:
            # (batch, seq, d_model) -> (batch, h, seq, d_model // h)
            return t.reshape(batch_size, t.shape[1], self.h, self.d).transpose(1, 2)

        out = self.attention(
            split_heads(q @ self.wq),
            split_heads(k @ self.wk),
            split_heads(v @ self.wv),
        )
        # (batch, h, q_seq, d) -> (batch, q_seq, h, d) -> concat heads -> (batch, q_seq, d_model)
        out = out.transpose(1, 2).reshape(batch_size, q_seq_len, d_model)
        return out @ self.wo

    def attention(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Scaled dot-product attention for all heads at once (broadcast matmul).

        Shapes:
            q: (batch_size, num_heads, q_seq_len, d_k)
            k: (batch_size, num_heads, k_seq_len, d_k)
            v: (batch_size, num_heads, k_seq_len, d_v)
        Returns:
            (batch_size, num_heads, q_seq_len, d_v)
        """
        scores = q @ k.transpose(-2, -1) / self.root_d
        return torch.softmax(scores, dim=-1) @ v

In [208]:
class Encoder(nn.Module):
    """A stack of ``n_layers`` EncoderLayer blocks applied one after another.

    Args:
        n_layers: number of stacked encoder layers.
        emb_dim: token embedding width (same as d_model).
        num_heads: attention heads per layer.
        d_ff: hidden width of each layer's feed-forward network.
    """

    def __init__(self, n_layers: int, emb_dim: int, num_heads: int, d_ff: int) -> None:
        super().__init__()
        self.model = nn.Sequential(
            *(EncoderLayer(emb_dim, num_heads, d_ff) for _ in range(n_layers))
        )

    def forward(self, x: Tensor) -> Tensor:
        """Feed ``x`` of shape (batch_size, seq_len, emb_dim) through every layer in order."""
        return self.model(x)


In [209]:
emb_dim = 512
batch_size = 7
seq_len = 13
num_heads = 8
dff = 2048
n_layers = 11

# Seed before drawing the random input and building the model so the
# demo (input and Xavier-initialised weights) is reproducible.
torch.manual_seed(0)
x = torch.rand(batch_size, seq_len, emb_dim)

encoder = Encoder(n_layers, emb_dim, num_heads, dff)

encoder_output = encoder(x)
# Every encoder layer keeps the (batch, seq, d_model) shape.
assert encoder_output.shape == x.shape
encoder_output.shape

torch.Size([7, 13, 512])

## Decoder

In [217]:
class DecoderLayer(nn.Module):
    """One post-LN Transformer decoder layer.

    Three sub-layers, each wrapped as ``LayerNorm(x + sublayer(x))``:
      1. causally masked multi-head self-attention over the layer input,
      2. unmasked multi-head cross-attention whose queries come from (1) and
         whose keys/values come from the encoder output,
      3. a position-wise feed-forward network.
    """
    def __init__(
            self,
            d_model: int,
            num_heads: int,
            d_ff: int,
        ) -> None:
        """
        Args:
            d_model: width of the token representations; must be divisible
                by ``num_heads``.
            num_heads: number of attention heads.
            d_ff: hidden width of the feed-forward network.

        Raises:
            ValueError: if ``d_model`` is not divisible by ``num_heads``.
        """
        if d_model % num_heads != 0:
            raise ValueError("Number d_model must be divisible by num_heads")
        super().__init__()
        self.h = num_heads
        self.d = d_model // num_heads  # per-head width (d_k = d_v)
        # Attention scale sqrt(d_k) kept as a plain float (device-agnostic).
        self.root_d = self.d ** 0.5
        # Bottom attention: masked self-attention over the layer input.
        self.wq1 = param(d_model, d_model)
        self.wk1 = param(d_model, d_model)
        self.wv1 = param(d_model, d_model)
        self.wo1 = param(d_model, d_model)
        # Top attention: cross-attention with the encoder output, no mask.
        self.wq2 = param(d_model, d_model)
        self.wk2 = param(d_model, d_model)
        self.wv2 = param(d_model, d_model)
        self.wo2 = param(d_model, d_model)
        # FFN parameters. The biases must be nn.Parameters (not plain tensors)
        # so they are trained, saved in state_dict() and moved by .to(device).
        self.w1 = param(d_model, d_ff)
        self.b1 = nn.Parameter(torch.zeros(d_ff))
        self.w2 = param(d_ff, d_model)
        self.b2 = nn.Parameter(torch.zeros(d_model))
        # Layer norms applied after each residual connection.
        self.ln1 = nn.LayerNorm(d_model)  # bottom attention
        self.ln2 = nn.LayerNorm(d_model)  # top attention
        self.ln_ffn = nn.LayerNorm(d_model)

    def forward(self, encoder_output: Tensor, decoder_output: Tensor) -> Tensor:
        """
        Args:
            encoder_output: (batch_size, src_len, d_model) output of the encoder.
            decoder_output: (batch_size, tgt_len, d_model) input to this layer
                (the previous decoder layer's output, or the decoder input
                for the first layer).

        Returns:
            (batch_size, tgt_len, d_model)
        """
        # Masked self-attention: each position sees only itself and earlier positions.
        mha1 = self.multi_head_attention(
                decoder_output,
                decoder_output,
                decoder_output,
                self.wq1,
                self.wk1,
                self.wv1,
                self.wo1,
                apply_mask=True
            )
        x = self.ln1(mha1 + decoder_output)

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        mha2 = self.multi_head_attention(
                x,
                encoder_output,
                encoder_output,
                self.wq2,
                self.wk2,
                self.wv2,
                self.wo2,
                apply_mask=False
            )
        x = self.ln2(mha2 + x)

        # Position-wise feed-forward network.
        ffn = torch.relu(x @ self.w1 + self.b1) @ self.w2 + self.b2
        return self.ln_ffn(ffn + x)

    def multi_head_attention(
            self,
            q: Tensor,
            k: Tensor,
            v: Tensor,
            wq: Tensor,
            wk: Tensor,
            wv: Tensor,
            wo: Tensor,
            apply_mask: bool,
        ) -> Tensor:
        """Project with the given weights, split into heads, attend, merge heads, project out.

        Input shapes are (batch_size, seq_len, d_model); ``k`` and ``v`` must
        share a sequence length, which may differ from that of ``q``.
        Returns (batch_size, q_seq_len, d_model).
        """
        batch_size, q_seq_len, d_model = q.shape

        def split_heads(t: Tensor) -> Tensor:
            # (batch, seq, d_model) -> (batch, h, seq, d_model // h)
            return t.reshape(batch_size, t.shape[1], self.h, self.d).transpose(1, 2)

        x = self.attention(split_heads(q @ wq), split_heads(k @ wk), split_heads(v @ wv), apply_mask)
        # (batch, h, q_seq, d) -> (batch, q_seq, h, d) -> concat heads -> (batch, q_seq, d_model)
        x = x.transpose(1, 2).reshape(batch_size, q_seq_len, d_model)
        return x @ wo

    def attention(self, q: Tensor, k: Tensor, v: Tensor, apply_mask: bool) -> Tensor:
        """Scaled dot-product attention for all heads at once (broadcast matmul).

        If ``apply_mask`` is True, a causal mask puts -inf above the diagonal
        of the scores before the softmax. Query i then gets zero weight on
        every key j > i.

        Shapes:
            q: (batch_size, num_heads, q_seq_len, d_k)
            k: (batch_size, num_heads, k_seq_len, d_k)
            v: (batch_size, num_heads, k_seq_len, d_v)
        Returns:
            (batch_size, num_heads, q_seq_len, d_v)
        """
        scores = q @ k.transpose(-2, -1) / self.root_d
        if apply_mask:
            # Build the mask on the scores' device/dtype so this also works on
            # GPU and in reduced precision.
            scores = scores + self.get_mask(
                q.shape[-2], k.shape[-2], device=scores.device, dtype=scores.dtype
            )
        return torch.softmax(scores, dim=-1) @ v

    def get_mask(
            self,
            seq_len: int,
            k_len: Optional[int] = None,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
        ) -> Tensor:
        """Additive causal mask of shape (seq_len, k_len).

        Entries on or below the main diagonal are 0; entries above it are
        -inf. ``k_len`` defaults to ``seq_len``. ``dtype=None`` gives torch's
        default float dtype.
        """
        if k_len is None:
            k_len = seq_len
        allowed = torch.ones(seq_len, k_len, dtype=torch.bool, device=device).tril()
        mask = torch.zeros(seq_len, k_len, dtype=dtype, device=device)
        return mask.masked_fill(~allowed, float('-inf'))

In [226]:
class Decoder(nn.Module):
    """A stack of DecoderLayer blocks followed by a projection onto the vocabulary.

    The output is a softmax distribution over ``vocab_size`` for every
    decoder position.
    """
    def __init__(
            self,
            n_layers: int,
            emb_dim: int,  # token embedding dim. Same as d_model
            num_heads: int,
            d_ff: int,  # hidden dim of FFN
            vocab_size: int
            ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(DecoderLayer(emb_dim, num_heads, d_ff) for _ in range(n_layers))
        # NOTE(review): the trailing Softmax yields probabilities. For training
        # with nn.CrossEntropyLoss, use the logits from final_projection[0]
        # instead; that loss expects unnormalized logits.
        self.final_projection = nn.Sequential(
            nn.Linear(emb_dim, vocab_size),
            nn.Softmax(-1),
        )

    def forward(self, encoder_output: Tensor, decoder_output: Tensor) -> Tensor:
        """Run the decoder stack, then project to vocabulary probabilities.

        Implemented as ``forward`` (not ``__call__``) so ``nn.Module.__call__``
        still runs registered hooks around it.

        Args:
            encoder_output: (batch_size, src_len, emb_dim) output of the encoder.
            decoder_output: (batch_size, tgt_len, emb_dim) decoder input.

        Returns:
            (batch_size, tgt_len, vocab_size); each position sums to 1.
        """
        x = decoder_output
        for layer in self.layers:
            x = layer(encoder_output, x)
        return self.final_projection(x)
        

In [227]:
vocab_size = 1000

# Decoder input with a single position per sequence.
y = torch.rand(batch_size, 1, emb_dim)

decoder = Decoder(n_layers, emb_dim, num_heads, dff, vocab_size)

decoder_output = decoder(encoder_output, y)
# final_projection ends in a softmax, so each position is a distribution over the vocabulary.
assert decoder_output.shape == (batch_size, 1, vocab_size)
assert torch.allclose(decoder_output.sum(-1), torch.ones(batch_size, 1))
decoder_output.shape

torch.Size([7, 1, 1000])