In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict

In [3]:
# Seed torch's global RNG so the randomly initialised weights and random inputs
# created below are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

$$\Large
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
$$

In [9]:
def scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0):  # q,k[B, nheads, seq_len, depth]
    """Compute softmax(q @ k^T / sqrt(depth)) @ v.

    Args:
        q: queries, [..., seq_len_q, depth] (e.g. [B, nheads, seq_len_q, depth]).
        k: keys, [..., seq_len_k, depth].
        v: values, [..., seq_len_k, depth_v].
        attn_mask: optional mask broadcastable to [..., seq_len_q, seq_len_k].
            A floating-point mask is added to the attention logits (use -inf to
            block a position). A boolean mask marks positions that must NOT be
            attended to with True (the nn.MultiheadAttention convention).
        dropout_p: dropout probability applied to the attention weights.
            Defaults to 0.0 (no dropout). Dropout is applied whenever
            dropout_p > 0, so pass 0.0 at evaluation time.

    Returns:
        (out, attn):
            out: [..., seq_len_q, depth_v].
            attn: the attention weights, [..., seq_len_q, seq_len_k], after
                dropout if any.
    """
    depth = q.size(-1)
    # Scaling q before the matmul is equivalent to scaling the logits afterwards.
    q = q * depth ** -0.5

    attn = torch.matmul(q, k.transpose(-2, -1))  # [..., seq_len_q, seq_len_k]
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Adding a bool mask would only add 1.0 to the masked logits;
            # block those positions outright instead.
            attn = attn.masked_fill(attn_mask, float("-inf"))
        else:
            # Out-of-place so a mask with extra leading dims can broadcast.
            # The cast keeps attn's dtype, so the matmul with v still matches.
            attn = attn + attn_mask.to(attn.dtype)
    attn = attn.softmax(dim=-1)
    if dropout_p > 0.0:
        attn = F.dropout(attn, p=dropout_p)
    out = torch.matmul(attn, v)  # [..., seq_len_q, depth_v]
    return out, attn

In [10]:
class mha(nn.Module):  # multi headed attention
    """Batch-first multi-head attention built on scaled_dot_product_attention.

    Args:
        emb_dim: model width; must be divisible by ``nheads``.
        nheads: number of attention heads; each head sees emb_dim // nheads features.

    Input and output shapes:
        Inputs: query [B, seq_len_q, emb_dim] and key/value
            [B, seq_len_k, emb_dim]. seq_len_k may differ from seq_len_q.
        Returns (out, attn):
            out: [B, seq_len_q, emb_dim].
            attn: per-head weights [B, nheads, seq_len_q, seq_len_k],
                not averaged over heads.
    """

    def __init__(self, emb_dim=512, nheads=8):
        super().__init__()
        if emb_dim % nheads != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by nheads ({nheads})"
            )
        self.nheads = nheads
        self.qw = nn.Linear(emb_dim, emb_dim)
        self.kw = nn.Linear(emb_dim, emb_dim)
        self.vw = nn.Linear(emb_dim, emb_dim)
        self.out_proj = nn.Linear(emb_dim, emb_dim)

    def _split_heads(self, t):
        """Reshape [B, seq_len, emb_dim] -> [B, nheads, seq_len, depth]."""
        batch_size, seq_len, _ = t.shape
        return t.view(batch_size, seq_len, self.nheads, -1).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None):   # [B, seq_len, emb_dim]
        batch_size, seq_len_q, _ = query.shape
        # Each projection is split using its own sequence length, so key/value
        # may be longer or shorter than the query (e.g. cross-attention).
        q = self._split_heads(self.qw(query))
        k = self._split_heads(self.kw(key))
        v = self._split_heads(self.vw(value))

        out, attn = scaled_dot_product_attention(q, k, v, attn_mask)   # [B, nheads, seq_len_q, depth]

        # Merge heads: [B, nheads, seq_len_q, depth] -> [B, seq_len_q, emb_dim]
        out = out.transpose(1, 2).reshape(batch_size, seq_len_q, -1)
        return self.out_proj(out), attn

In [11]:
# Our attention module, built with the same sizes (512 wide, 8 heads) as the
# nn.MultiheadAttention reference below. It stays on the CPU; to use the GPU,
# move both the module and its inputs with .to(device).
atn = mha(emb_dim=512, nheads=8)

In [12]:
x = torch.randn((8, 64, 512))  # random batch: [B=8, seq_len=64, emb_dim=512], on the CPU

In [13]:
# Self-attention: x serves as query, key and value.
out, attn = atn(query=x, key=x, value=x)
# attn is per-head [B, nheads, seq_q, seq_k]. Averaging over the head axis gives
# the same [B, seq_q, seq_k] shape that nn.MultiheadAttention returns below.
(out.shape, attn.mean(dim=1).shape)

(torch.Size([8, 64, 512]), torch.Size([8, 64, 64]))

In [14]:
# PyTorch's built-in module as a reference: embed_dim=512, num_heads=8.
# With the default batch_first=False it expects [seq_len, B, embed_dim] inputs.
tmha = nn.MultiheadAttention(512, 8)

In [15]:
# Swap the batch and sequence axes to get the seq-first layout nn.MultiheadAttention uses.
y = x.transpose(0, 1)
y.shape

torch.Size([64, 8, 512])

In [16]:
# Reference forward pass. The returned weights are already averaged over heads.
o, a = tmha(query=y, key=y, value=y)
o.shape, a.shape

(torch.Size([64, 8, 512]), torch.Size([8, 64, 64]))

In [17]:
def get_ffn(d_model, dp_rate=0.1, dim_ff=None):
    """Build the position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output feature size.
        dp_rate: dropout probability applied after the activation.
        dim_ff: hidden size of the inner layer; defaults to 4 * d_model.

    Returns:
        nn.Sequential with named children 'lin1', 'activation', 'dropout', 'lin2'.
    """
    hidden = 4 * d_model if dim_ff is None else dim_ff
    return nn.Sequential(OrderedDict([
            ('lin1', nn.Linear(d_model, hidden)),
            ('activation', nn.ReLU()),
            ('dropout', nn.Dropout(dp_rate)),
            ('lin2', nn.Linear(hidden, d_model))
        ]))

In [31]:
class TransformerEncoderLayer(nn.Module):
    """Post-LN Transformer encoder layer.

    It has two sublayers, each applied as LayerNorm(x + Dropout(sublayer(x))):
    multi-head self-attention, then the feed-forward network from get_ffn.
    """

    def __init__(self, d_model=512, nheads=8, dp_rate=0.1):
        super().__init__()
        # self-attention sublayer
        self.attn = mha(emb_dim=d_model, nheads=nheads)
        self.dropout1 = nn.Dropout(dp_rate)
        self.ln1 = nn.LayerNorm(d_model)
        # feed-forward sublayer
        self.ffn = get_ffn(d_model, dp_rate)
        self.dropout2 = nn.Dropout(dp_rate)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None):   # [B, seq_len, d_model]
        attn_out, _ = self.attn(x, x, x, attn_mask=attn_mask)
        x = self.ln1(x + self.dropout1(attn_out))
        ffn_out = self.ffn(x)
        return self.ln2(x + self.dropout2(ffn_out))

In [32]:
class TransformerDecoderLayer(nn.Module):
    """Post-LN Transformer decoder layer.

    It has three sublayers, each with its own dropout and LayerNorm and applied
    as LayerNorm(x + Dropout(sublayer(x))):
      1. self-attention over x (attn1 / dropout1 / ln1),
      2. cross-attention from x to enc_output (attn2 / dropout2 / ln2),
      3. the feed-forward network (ffn / dropout3 / ln3).
    """

    def __init__(self, d_model=512, nheads=8, dp_rate=0.1):
        super().__init__()
        self.attn1 = mha(emb_dim=d_model, nheads=nheads)
        self.dropout1 = nn.Dropout(dp_rate)
        self.ln1 = nn.LayerNorm(d_model)

        self.attn2 = mha(emb_dim=d_model, nheads=nheads)
        self.dropout2 = nn.Dropout(dp_rate)
        self.ln2 = nn.LayerNorm(d_model)

        self.ffn = get_ffn(d_model, dp_rate)
        self.dropout3 = nn.Dropout(dp_rate)
        self.ln3 = nn.LayerNorm(d_model)

    @staticmethod
    def _combine_masks(memory_mask, padding_mask):
        """Merge the two optional cross-attention masks into one, or None if neither is given."""
        if memory_mask is None:
            return padding_mask
        if padding_mask is None:
            return memory_mask
        if memory_mask.dtype == torch.bool and padding_mask.dtype == torch.bool:
            return memory_mask | padding_mask  # masked if either mask blocks the position
        return memory_mask + padding_mask  # additive float masks

    def forward(self, x, enc_output, attn_mask=None, memory_mask=None, padding_mask=None):   # [B, seq_len, d_model]
        """Run the decoder layer.

        Args:
            x: decoder input, [B, seq_len, d_model].
            enc_output: encoder output, used as key and value for cross-attention.
            attn_mask: mask for the self-attention sublayer (e.g. a causal mask).
            memory_mask: mask for the cross-attention logits.
            padding_mask: extra mask for the cross-attention. When both it and
                memory_mask are given they are merged: OR-ed if both are bool,
                otherwise summed, so they should then be additive float masks.

        Returns:
            Tensor with the same shape as x.
        """
        self_out, _ = self.attn1(x, x, x, attn_mask=attn_mask)
        x = self.ln1(x + self.dropout1(self_out))

        cross_mask = self._combine_masks(memory_mask, padding_mask)
        cross_out, _ = self.attn2(x, enc_output, enc_output, attn_mask=cross_mask)
        x = self.ln2(x + self.dropout2(cross_out))

        x = self.ln3(x + self.dropout3(self.ffn(x)))
        return x

In [33]:
te = TransformerEncoderLayer(d_model=512, nheads=8, dp_rate=0.1)  # defaults, written out explicitly

In [None]:
def reset_parameters(self):
    """Apply Xavier/Glorot-uniform init to every parameter of ``self`` with rank >= 2.

    Parameters with rank >= 2 include, for example, nn.Linear weight matrices.
    1-D parameters, such as biases and LayerNorm weights, are left unchanged.

    This is a free function that takes the module as its argument, so call it
    as ``reset_parameters(model)``.
    """
    matrices = (param for param in self.parameters() if param.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)