## Attention Is All You Need

Original paper: https://arxiv.org/pdf/1706.03762.pdf

Implementing a Transformer (mostly) from scratch using NumPy and PyTorch.

In [271]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import math

In [None]:
"""
Need to implement
- Scaled dot-product attention
- Multi-head attention
- Positional encoding
- Layer normalization
- Position-wise feed forward
- Encoder layer (combination of some of the above)
- Encoder (stack of encoder layers)
- Multi-head cross attention
- Decoder layer
- Decoder
- Transformer (combining encoder and decoder)
"""

In [272]:
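# Example additive causal mask: 0 on and below the diagonal, -inf above it.
# Here x must be a tensor shaped like the attention score matrix.
# mask = torch.triu(torch.ones_like(x) * float('-inf'), diagonal=1)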

def scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: "torch.Tensor | None" = None,
) -> torch.Tensor:
    """Compute ``softmax((q @ k^T + mask) / sqrt(d_k)) @ v``.

    The attention is taken over the last two dimensions; any leading
    dimensions (batch, heads, ...) are handled by matmul broadcasting.

    Args:
        q: Queries; last two dims are (query positions, d_k).
        k: Keys; last two dims are (key positions, d_k).
        v: Values; last two dims are (key positions, d_v).
        mask: Optional additive mask that must broadcast against the
            (..., query positions, key positions) score matrix. Use
            ``-inf`` where attention is forbidden and 0 elsewhere. It is
            added to the raw scores before they are scaled.

    Returns:
        The attention output; last two dims are (query positions, d_v).
        For floating-point ``v`` the result has the dtype of ``v``.
    """
    scores = q @ k.transpose(-2, -1)
    if mask is not None:
        scores = scores + mask
    scaled = scores / math.sqrt(k.shape[-1])

    # Run the softmax in at least float32 for stability with half-precision
    # inputs, but never *down*cast (float64 scores stay float64).
    softmax_dtype = torch.promote_types(scaled.dtype, torch.float32)
    weights = F.softmax(scaled, dim=-1, dtype=softmax_dtype)

    # matmul does not promote dtypes, so bring the weights back to v's dtype;
    # otherwise any non-float32 v fails with a dtype-mismatch error.
    if v.is_floating_point():
        weights = weights.to(v.dtype)
    return weights @ v

In [273]:
# Change this to take q, k, v as input like the original paper
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    The input is projected once to ``3 * embed_dim`` features, split into
    ``heads`` groups, and each group is chunked into its query, key and
    value before scaled dot-product attention is applied per head.

    NOTE: there is no output projection after the heads are concatenated;
    the concatenated head outputs are returned directly.

    Args:
        heads: Number of attention heads.
        embed_dim: Size of the input/output feature dimension; must be
            divisible by ``heads``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``heads``.
    """

    def __init__(self, heads, embed_dim):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # on the first forward pass.
        if embed_dim % heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by heads ({heads})"
            )
        self.heads = heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // heads
        # One linear layer produces Q, K and V for every head at once.
        self.qkv_linear = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_length, embed_dim).
            mask: Optional additive mask forwarded to the attention
                function; it must broadcast against the per-head scores of
                shape (batch_size, heads, seq_length, seq_length).

        Returns:
            Tensor of shape (batch_size, seq_length, embed_dim).
        """
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_linear(x)
        # (batch, seq, 3 * embed) -> (batch, seq, heads, 3 * head_dim)
        qkv = qkv.reshape(batch_size, seq_length, self.heads, self.head_dim * 3)
        # Move heads in front of the sequence axis: (batch, heads, seq, 3 * head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Forward the mask; previously it was accepted but silently dropped.
        x = scaled_dot_product_attention(q, k, v, mask)
        # (batch, heads, seq, head_dim) -> (batch, seq, embed): concatenate heads.
        x = x.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.embed_dim)
        return x

In [274]:
# Smoke test: push a random batch through the layer and check the output shape.
# The input dimensions are unpacked by the layer as (batch_size, seq_length, embedding).
test = torch.randn((30,50,512))

# 8 heads over a 512-wide embedding, i.e. 512 // 8 = 64 features per head.
mh = MultiHeadAttention(8, 512)
res = mh(test)
# The output shape equals the input shape: torch.Size([30, 50, 512]).
print(res.shape)
# Check if tensors a and b are equal element-wise within a threshold (here 1e-10).
# torch.allclose(a, b, rtol=0, atol=1e-10) is a near-equivalent built-in.
# torch.all(torch.lt(torch.abs(torch.add(a, -b)), 1e-10))

torch.Size([30, 50, 512])
