https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html


https://www.youtube.com/watch?v=U0s0f995w14


https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py

In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [41]:
class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        
        assert (self.head_dim * num_heads == model_dim), "Embed size needs to be div by heads!"
        
        self.linearK = nn.Linear(model_dim, model_dim, bias=False)
        self.linearV = nn.Linear(model_dim, model_dim, bias=False)
        self.linearQ = nn.Linear(model_dim, model_dim, bias=False)
        self.fc_out = nn.Linear(model_dim, model_dim, bias=False)
        
    def forward(self, K, V, Q, mask):
        N = Q.shape[0]
        
        K_len = K.shape[1]
        V_len = V.shape[1]
        Q_len = Q.shape[1]

        K = self.linearK(K).reshape((N, K_len, self.num_heads, self.head_dim))
        V = self.linearV(V).reshape((N, V_len, self.num_heads, self.head_dim))
        Q = self.linearQ(Q).reshape((N, Q_len, self.num_heads, self.head_dim))

        energy = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
        
        energy /= self.head_dim ** (1/2)

        print(energy.shape)
        print(mask.shape)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy, dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, V]).reshape((N, Q_len, self.model_dim))

        out = self.fc_out(out)
        
        return out

In [42]:
class SubLayerMHA(nn.Module):
    def __init__(self, model_dim, num_heads, dropout):
        super(SubLayerMHA, self).__init__()
        self.mha = MultiHeadAttention(model_dim, num_heads)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(model_dim)
        
    def forward(self, K, V, Q, mask):
        attn_out = self.mha(K, V, Q, mask)
        out = self.norm(Q + self.dropout(attn_out))
        
        return out
    
    
class SubLayerFFN(nn.Module):
    def __init__(self, model_dim, ff_dim, dropout):
        super(SubLayerFFN, self).__init__()
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, model_dim),
        )
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(model_dim)
        
    def forward(self, x):
        ffn_out = self.ffn(x)
        out = self.norm(x + self.dropout(ffn_out))

        return out

In [43]:
class EncoderBlock(nn.Module):
    def __init__(self, model_dim, num_heads, ff_dim, dropout):
        super(EncoderBlock, self).__init__()
        self.mha = SubLayerMHA(model_dim, num_heads, dropout)
        self.ffn = SubLayerFFN(model_dim, ff_dim, dropout)
        
    def forward(self, x, pad_mask):
        x = self.mha(x, x, x, pad_mask)
        x = self.ffn(x)
        return x

In [44]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, model_dim, num_heads, ff_dim, dropout, num_layers):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        self.pe = PositionalEncoding(model_dim)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [EncoderBlock(model_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        
    def forward(self, x, pad_mask):
        out = self.dropout(self.pe(self.embed(x)))
        for layer in self.layers:
            out = layer(out, pad_mask)
        return out

In [45]:
class DecoderBlock(nn.Module):
    def __init__(self, model_dim, num_heads, ff_dim, dropout):
        super(DecoderBlock, self).__init__()
        self.masked_mha = SubLayerMHA(model_dim, num_heads, dropout)
        self.mha = SubLayerMHA(model_dim, num_heads, dropout)
        self.ffn = SubLayerFFN(model_dim, ff_dim, dropout)
            
    def forward(self, enc_out, x, pad_mask, mha_mask):
        out = self.masked_mha(x, x, x, mha_mask)
        out = self.mha(enc_out, enc_out, x, pad_mask)
        out = self.ffn(out)
        return out

In [46]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, model_dim, num_heads, ff_dim, dropout, num_layers):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        self.pe = PositionalEncoding(model_dim)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [DecoderBlock(model_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        
    def forward(self, enc_out, x, pad_mask, mha_mask):
        out = self.dropout(self.pe(self.embed(x)))
        for layer in self.layers:
            out = layer(enc_out, out, pad_mask, mha_mask)
        return out

In [47]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size, pad_size,  
                 model_dim=512, 
                 num_heads=8, 
                 ff_dim=2048, 
                 dropout=0.1, 
                 num_layers=6, 
                 device="cuda"):
        super(Transformer, self).__init__()
        
        self.pad_size = pad_size
        self.device = device
        self.encoder = Encoder(src_vocab_size, model_dim, num_heads, ff_dim, dropout, num_layers)
        self.decoder = Decoder(trg_vocab_size, model_dim, num_heads, ff_dim, dropout, num_layers)
        
    def make_pad_mask(self, input):
        pad_mask = (input != self.pad_size).unsqueeze(1).unsqueeze(2)
        return pad_mask.to(self.device)
    
    def make_mha_mask(self, output):
        N, output_len = output.shape
        mha_mask = torch.tril(torch.ones((output_len, output_len))).expand(N, 1, output_len, output_len)
        return mha_mask.to(self.device)
    
    def forward(self, input, output):
        inp_pad_mask = self.make_pad_mask(input)
        out_pad_mask = self.make_pad_mask(output)
        mha_mask = self.make_mha_mask(output)
        enc_out = self.encoder(input, inp_pad_mask)
        dec_out = self.decoder(enc_out, output, out_pad_mask, mha_mask)
        return dec_out    

In [48]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

input = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
output = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

pad_size = 0
trg_pad_idx = 0
src_vocab_size = 30
trg_vocab_size = 30

model = Transformer(src_vocab_size, trg_vocab_size, pad_size, num_layers=1).to(device)

dec_out = model(input, output)


cuda
torch.Size([2, 8, 9, 9])
torch.Size([2, 1, 1, 9])
torch.Size([2, 8, 8, 8])
torch.Size([2, 1, 9, 9])


RuntimeError: The size of tensor a (9) must match the size of tensor b (8) at non-singleton dimension 3