In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.distributions import Categorical

import math
import copy

from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from torch.utils.data import Dataset

from einops import rearrange

# from https://medium.com/towards-data-science/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key, and value inputs each pass through their own linear layer,
    are split into ``num_heads`` heads of width ``d_k = d_model // num_heads``,
    attended per head, merged back together, and finally passed through an
    output linear layer.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # feature width of a single head

        # One projection per role, plus the output projection applied after
        # the heads have been merged back together.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Attend over ``V`` with weights ``softmax(Q @ K^T / sqrt(d_k))``.

        Args:
            Q: Queries, shape ``(..., query_len, d_k)``.
            K: Keys, shape ``(..., key_len, d_k)``.
            V: Values, shape ``(..., key_len, d_k)``.
            mask: Optional tensor broadcastable to the score shape
                ``(..., query_len, key_len)``. Positions where ``mask == 0``
                are excluded from the attention.

        Returns:
            Tensor of shape ``(..., query_len, d_k)``.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Fill with the most negative finite value of the score dtype.
            # A hard-coded -1e9 is not representable in float16, which makes
            # masked_fill raise under half precision / autocast.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, num_heads, seq, d_k)``."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: ``(batch, num_heads, seq, d_k)`` back to
        ``(batch, seq, d_model)``."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() returns a non-contiguous view, so contiguous() is
        # required before view() can flatten the head dimensions.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Run multi-head attention.

        Args:
            Q: Query input, shape ``(batch, query_len, d_model)``.
            K: Key input, shape ``(batch, key_len, d_model)``.
            V: Value input, shape ``(batch, key_len, d_model)``.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``.
        """
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [14]:
# Small attention layer for shape experiments: d_model=36 split over 6 heads,
# so each head is 6 wide. Use the GPU when one is available and fall back to
# the CPU otherwise, so the notebook also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mha = MultiHeadAttention(36, 6).to(device)

# Two random inputs with batch size 1 and 36 features: x has 10 positions,
# y has 5, so the query length and key/value length can be told apart below.
x = torch.randn(1, 10, 36, device=device)
y = torch.randn(1, 5, 36, device=device)

In [15]:
# Self-attention: x supplies the queries, the keys, and the values.
a = mha(Q=x, K=x, V=x)

In [16]:
# The self-attention output has the same shape as the input x.
a.shape

torch.Size([1, 10, 36])

In [17]:
# Cross-attention: queries come from x, keys and values come from y.
b = mha(Q=x, K=y, V=y)

In [18]:
# The output length follows the query input x (10 positions), not y (5).
b.shape

torch.Size([1, 10, 36])

In [19]:
# Redo the first step of mha.forward by hand for the cross-attention case:
# project each input with its own linear layer, then split into heads.
# Queries are built from x; keys and values are built from y.
Q, K, V = (
    mha.split_heads(projection(source))
    for projection, source in ((mha.W_q, x), (mha.W_k, y), (mha.W_v, y))
)

In [20]:
# Per-head queries: (batch, heads, query positions, d_k).
Q.shape

torch.Size([1, 6, 10, 6])

In [21]:
# Per-head keys: (batch, heads, key positions, d_k).
K.shape

torch.Size([1, 6, 5, 6])

In [22]:
# Per-head values: same layout as the keys.
V.shape

torch.Size([1, 6, 5, 6])

In [23]:
# Swapping the last two axes puts d_k before the key positions, which is the
# layout needed to multiply Q by the keys.
torch.transpose(K, -2, -1).shape

torch.Size([1, 6, 6, 5])

In [24]:
# Scaled dot-product scores: every query position is scored against every key
# position. The division by sqrt(d_k) matches what
# MultiHeadAttention.scaled_dot_product_attention does, so this manual
# walkthrough reproduces the layer's forward pass rather than only its shapes.
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(mha.d_k)
scores.size()

torch.Size([1, 6, 10, 5])

In [26]:
# Normalise over the key axis so each query's weights sum to 1.
probs = scores.softmax(dim=-1)

In [27]:
# Softmax leaves the shape of the scores unchanged.
probs.shape

torch.Size([1, 6, 10, 5])

In [28]:
# Weighted sum of the values: the key axis is contracted away, leaving one
# d_k-wide vector per head and query position.
out = probs @ V
out.shape

torch.Size([1, 6, 10, 6])

In [29]:
# Merge the heads back into a single d_model-wide vector per query position,
# then apply the output projection.
final = mha.W_o(
    mha.combine_heads(out)
)
final.shape

torch.Size([1, 10, 36])

In [30]:
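# OK, I finally understand decoders: in cross-attention the queries come from one
# sequence and the keys/values from another, and the output keeps the query length.
# No need to code up the full implementation; that would be boring.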