In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

$$\Large
\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
$$

In [101]:
def scaled_dot_product_attention(q, k, v, mask=None): # q,k[B, nheads, seq_len, depth]
    """Scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V, sequence-major layout.

    Args:
        q: queries, [B, nheads, seq_len_q, depth].
        k: keys, [B, nheads, seq_len_k, depth].
        v: values, [B, nheads, seq_len_k, depth_v].
        mask: optional tensor broadcastable to [B, nheads, seq_len_q, seq_len_k].
            A boolean mask uses True for positions that must NOT be attended
            to (the nn.MultiheadAttention convention); a mask of any other
            dtype is added to the attention logits before the softmax.
            A query row whose keys are all masked yields NaN weights.

    Returns:
        (out, attn): out is [B, nheads, seq_len_q, depth_v] and attn holds the
        softmax weights, [B, nheads, seq_len_q, seq_len_k].
    """
    depth = q.size(-1)
    scale = depth**-0.5
    q = q*scale  # scaling q up front is equivalent to scaling the logits

    attn = torch.matmul(q, k.transpose(-2,-1)) # [B, nheads, seq_len_q, seq_len_k]
    if mask is not None:
        if mask.dtype == torch.bool:
            # -inf logits become exactly zero weight after the softmax
            attn = attn.masked_fill(mask, float("-inf"))
        else:
            attn = attn + mask
    attn = attn.softmax(dim=-1)
    # dropout
    out = torch.matmul(attn, v)    # [B, nheads, seq_len_q, depth_v]
    return out, attn

In [102]:
class mha(nn.Module):  # multi headed attention
    """Multi-head attention over sequence-major inputs, [B, seq_len, emb_dim].

    Query, key and value are each projected with their own nn.Linear, split
    into `nheads` heads of size emb_dim // nheads, passed through the
    module-level `scaled_dot_product_attention`, merged back and projected
    once more by `out_proj`.

    Args:
        emb_dim: embedding size of inputs and outputs; must be divisible by nheads.
        nheads: number of attention heads.

    Raises:
        ValueError: if emb_dim is not divisible by nheads.
    """
    def __init__(self, emb_dim=512, nheads=8):
        super().__init__()
        if emb_dim % nheads != 0:
            raise ValueError(f"emb_dim ({emb_dim}) must be divisible by nheads ({nheads})")
        self.nheads = nheads
        self.qw = nn.Linear(emb_dim, emb_dim)
        self.kw = nn.Linear(emb_dim, emb_dim)
        self.vw = nn.Linear(emb_dim, emb_dim)
        self.out_proj = nn.Linear(emb_dim, emb_dim)

    def _split_heads(self, x):
        """[B, seq_len, emb_dim] -> [B, nheads, seq_len, depth]."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.nheads, -1).transpose(-2, -3)

    def forward(self, query, key, value):   # [B, seq_len, emb_dim]
        """Attend from `query` to `key`/`value`.

        `key` and `value` must share a sequence length, but it may differ from
        the query's (cross-attention): each tensor is split into heads using
        its own length rather than the query's.

        Returns:
            (out, attn): out is [B, seq_len_q, emb_dim]; attn is the per-head
            weight tensor [B, nheads, seq_len_q, seq_len_k].
        """
        batch_size, seq_len, _ = query.shape
        q = self._split_heads(self.qw(query))
        k = self._split_heads(self.kw(key))
        v = self._split_heads(self.vw(value))  # [B, nheads, seq_len, depth]

        out, attn = scaled_dot_product_attention(q,k,v)   # [B, nheads, seq_len_q, depth]

        out = out.transpose(-2,-3)  # [B, seq_len_q, nheads, depth]

        # reshape (not view): the transpose above leaves `out` non-contiguous
        out = out.reshape(batch_size, seq_len, -1)
        out = self.out_proj(out)
        return out, attn

In [174]:
def scaled_dot_product_attention(q, k, v, mask=None): # q,k[B, nheads, depth, seq_len]
    """Scaled dot-product attention, softmax(Q^T K / sqrt(d_k)), channels-first layout.

    Args:
        q: queries, [B, nheads, depth, seq_len_q].
        k: keys, [B, nheads, depth, seq_len_k].
        v: values, [B, nheads, depth_v, seq_len_k].
        mask: optional tensor broadcastable to [B, nheads, seq_len_q, seq_len_k].
            A boolean mask uses True for positions that must NOT be attended
            to (the nn.MultiheadAttention convention); a mask of any other
            dtype is added to the attention logits before the softmax.
            A query row whose keys are all masked yields NaN weights.

    Returns:
        (out, attn): out is [B, nheads, depth_v, seq_len_q] and attn holds the
        softmax weights, [B, nheads, seq_len_q, seq_len_k].
    """
    # In this layout the last axis is seq_len; the per-head depth d_k that the
    # 1/sqrt(d_k) factor needs sits on the second-to-last axis.
    depth = q.size(-2)
    scale = depth**-0.5
    q = q*scale

    attn = torch.matmul(q.transpose(-2,-1), k) # [B, nheads, seq_len_q, seq_len_k]
    if mask is not None:
        if mask.dtype == torch.bool:
            # -inf logits become exactly zero weight after the softmax
            attn = attn.masked_fill(mask, float("-inf"))
        else:
            attn = attn + mask
    attn = attn.softmax(dim=-1)
    # dropout
    out = torch.matmul(v, attn.transpose(-2,-1))    # [B, nheads, depth_v, seq_len_q]
    return out, attn

In [175]:
class mha(nn.Module):  # multi headed attention
    """Multi-head attention over channels-first inputs, [B, emb_dim, seq_len].

    Query, key and value are each projected with a kernel-size-1 nn.Conv1d
    (a per-position linear map), split along the channel axis into `nheads`
    heads of size emb_dim // nheads, passed through the module-level
    `scaled_dot_product_attention`, merged back and projected by `out_proj`.

    Args:
        emb_dim: channel count of inputs and outputs; must be divisible by nheads.
        nheads: number of attention heads.

    Raises:
        ValueError: if emb_dim is not divisible by nheads.
    """
    def __init__(self, emb_dim=512, nheads=8):
        super().__init__()
        if emb_dim % nheads != 0:
            raise ValueError(f"emb_dim ({emb_dim}) must be divisible by nheads ({nheads})")
        self.nheads = nheads
        self.qw = nn.Conv1d(emb_dim, emb_dim, kernel_size=1)
        self.kw = nn.Conv1d(emb_dim, emb_dim, kernel_size=1)
        self.vw = nn.Conv1d(emb_dim, emb_dim, kernel_size=1)
        self.out_proj = nn.Conv1d(emb_dim, emb_dim, kernel_size=1)

    def _split_heads(self, x):
        """[B, emb_dim, seq_len] -> [B, nheads, depth, seq_len]."""
        batch_size, _, seq_len = x.shape
        return x.view(batch_size, self.nheads, -1, seq_len)

    def forward(self, query, key, value):   # [B, emb_dim, seq_len]
        """Attend from `query` to `key`/`value`.

        `key` and `value` must share a sequence length, but it may differ from
        the query's (cross-attention): each tensor is split into heads using
        its own length rather than the query's.

        Returns:
            (out, attn): out is [B, emb_dim, seq_len_q]; attn is the per-head
            weight tensor [B, nheads, seq_len_q, seq_len_k].
        """
        batch_size, _, seq_len = query.shape
        q = self._split_heads(self.qw(query))
        k = self._split_heads(self.kw(key))
        v = self._split_heads(self.vw(value))     # [B, nheads, depth, seq_len]

        out, attn = scaled_dot_product_attention(q,k,v)   # [B, nheads, depth, seq_len_q]

        # merge heads back into the channel axis
        out = out.reshape(batch_size, -1, seq_len)
        out = self.out_proj(out)
        return out, attn

In [176]:
# Custom attention module under test (nheads defaults to 8); left on the CPU.
atn = mha(emb_dim=512)#.to(device)

In [177]:
# Random input in the channels-first layout taken by the Conv1d-based mha:
# [B=8, emb_dim=512, seq_len=64].
x = torch.randn(8, 512, 64)#.to(device)

In [178]:
%%timeit
# Times one self-attention forward pass of the custom module.
# NOTE: names assigned inside a %%timeit cell (out, attn) are not kept in the
# notebook namespace afterwards.
out, attn = atn(x,x,x)
out.shape, attn.mean(dim=1).shape

9.61 ms ± 47.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
# PyTorch's built-in multi-head attention, used as the reference to compare against.
tmha = nn.MultiheadAttention(embed_dim=512, num_heads=8)#.to(device)

In [72]:
# Shapes of the reference module's parameters, in registration order.
[param.shape for param in tmha.parameters()]

[torch.Size([1536, 512]),
 torch.Size([1536]),
 torch.Size([512, 512]),
 torch.Size([512])]

In [12]:
# x is channels-first, [B, emb_dim, seq_len]; nn.MultiheadAttention (default
# batch_first=False) expects [seq_len, B, emb_dim], so all three axes must move.
y = x.permute(2, 0, 1)
y.shape

torch.Size([64, 8, 512])

In [13]:
with torch.no_grad():
    # tmha packs its q/k/v projections into a single [3*emb_dim, emb_dim]
    # in_proj matrix (q, then k, then v), while atn keeps three separate
    # layers, so the parameters cannot be paired one-to-one with zip().
    # flatten(1) drops the trailing kernel axis of Conv1d weights and is a
    # no-op for Linear weights.
    projections = (atn.qw, atn.kw, atn.vw)
    tmha.in_proj_weight.copy_(torch.cat([p.weight.flatten(1) for p in projections]))
    tmha.in_proj_bias.copy_(torch.cat([p.bias for p in projections]))
    tmha.out_proj.weight.copy_(atn.out_proj.weight.flatten(1))
    tmha.out_proj.bias.copy_(atn.out_proj.bias)

In [179]:
%%timeit
# Times one self-attention forward pass of the PyTorch reference module.
# NOTE: names assigned inside a %%timeit cell (o, a) are not kept in the
# notebook namespace afterwards.
o,a = tmha(y,y,y)
o.shape, a.shape

10.3 ms ± 273 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
# Recompute both outputs here: names assigned inside the %%timeit cells above
# are not available in the notebook namespace.
x_seq_first = x.permute(2, 0, 1)              # [B, emb_dim, seq_len] -> [seq_len, B, emb_dim]
out, attn = atn(x, x, x)                      # out: [B, emb_dim, seq_len]
o, a = tmha(x_seq_first, x_seq_first, x_seq_first)   # o: [seq_len, B, emb_dim]
# Compare with a tolerance: the two implementations order their float ops
# differently, so bit-exact equality is not guaranteed.
torch.isclose(o, out.permute(2, 0, 1), atol=1e-5).all()

tensor(True)