In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

$\begin{equation*}\Large
Attention(Q,K,V) = softmax\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
\end{equation*}$

In [6]:
def attention(q, k, v, mask=None): # q,k[B, N, seq_len, depth]
    depth = q.size(-1)
    scale = depth**-0.5
    q = q*scale

    attn = torch.matmul(q, k.transpose(-2,-1)) # [B, N, seq_len, seq_len]
    attn = attn.softmax(dim=-1)
    # dropout
    out = torch.matmul(attn, v)    # [B, N, seq_len, depth]
    return out, attn

In [7]:
class mhsa(nn.Module):  # multi headed self attention
    def __init__(self, emb_dim=512, nheads=8):
        super().__init__()
        self.nheads = nheads
        self.to_qkv = nn.Linear(emb_dim, 3*emb_dim)
        self.out_proj = nn.Linear(emb_dim, emb_dim)

    def forward(self, x):   # [B, seq_len, emb_dim]
        batch_size, seq_len = x.shape[:2]
        q,k,v = self.to_qkv(x).chunk(3, dim=-1) # [B, seq_len, emb_dim]
        
        q = q.view(batch_size, seq_len, self.nheads, -1)
        k = k.view_as(q)
        v = v.view_as(q)     # [B, seq_len, N, depth]
        
        q = q.transpose(-2,-3)
        k = k.transpose(-2,-3)
        v = v.transpose(-2,-3)  # [B, N, seq_len, depth]
        
        out, attn = attention(q,k,v)
        
        out = out.transpose(-2,-3)  # [B, seq_len, N, depth]
        
        out = out.contiguous().view(batch_size, seq_len, -1)
        out = self.out_proj(out)
        return out, attn

In [8]:
atn = mhsa(emb_dim=512).to(device)

In [9]:
x = torch.randn(8, 64, 512).to(device)

In [10]:
%%time
out, attn = atn(x)
out.shape, attn.mean(dim=1).shape

CPU times: user 3.68 ms, sys: 77 µs, total: 3.75 ms
Wall time: 2.96 ms


(torch.Size([8, 64, 512]), torch.Size([8, 64, 64]))

In [11]:
tmha = nn.MultiheadAttention(embed_dim=512, num_heads=8).to(device)

In [12]:
y = x.transpose(1,0)
y.shape

torch.Size([64, 8, 512])

In [13]:
with torch.no_grad():
    for w,nw in zip(tmha.parameters(), atn.parameters()):
        w.copy_(nw)

In [14]:
o,a = tmha(y,y,y)
o.shape, a.shape

(torch.Size([64, 8, 512]), torch.Size([8, 64, 64]))

In [23]:
(o == out.transpose(1,0)).all()

tensor(True, device='cuda:0')