In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F

import numpy as np
import math

# Notebook-wide configuration: numeric dtype, compute device and logging cadence.
USE_GPU = True
dtype = torch.float32

# Pick CUDA only when it is both requested and actually available.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')

print_every = 100  # presumably a logging interval for a training loop; unused in the visible cells
print('using device:', device)

using device: cuda


In [13]:
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected into queries, keys and values. The embedding
    dimension is split into ``num_heads`` heads of size
    ``embed_dim // num_heads``, attention is computed independently per head,
    and the heads are merged and passed through an output projection.

    Args:
        embed_dim: size of the last dimension of the input and the output.
        num_heads: number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, attn_mask=False):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape [B, N, embed_dim].
            attn_mask: selects which keys each query may attend to.
                * ``False`` / ``None`` (default): no masking.
                * ``True``: causal mask; query i only sees keys j <= i.
                * bool tensor broadcastable to [B, H, N, N] (e.g. [N, N] or
                  [B, 1, 1, N]): ``True`` marks positions that MAY be attended
                  to (same convention as ``F.scaled_dot_product_attention``).
                * floating tensor broadcastable to [B, H, N, N]: added to the
                  attention scores before the softmax (use ``-inf`` to block).

        Returns:
            Tensor of shape [B, N, embed_dim].

        Note:
            A query whose keys are all masked out produces NaN after the
            softmax. The causal mask cannot do this because the diagonal is
            always kept.
        """
        B, N, _ = x.shape

        def split_heads(t):
            # [B, N, embed_dim] -> [B, H, N, head_dim]
            return t.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.q_proj(x))
        K = split_heads(self.k_proj(x))
        V = split_heads(self.v_proj(x))

        # Scaled dot-product scores: [B, H, N, N]
        A = (Q @ K.transpose(-1, -2)) / math.sqrt(self.head_dim)

        if isinstance(attn_mask, torch.Tensor):
            if attn_mask.dtype == torch.bool:
                A = A.masked_fill(~attn_mask, float('-inf'))
            else:
                # Cast so a float64 mask cannot silently promote the scores.
                A = A + attn_mask.to(A.dtype)
        elif attn_mask:
            # True strictly above the diagonal marks "future" keys to hide.
            future = torch.ones((N, N), dtype=torch.bool, device=x.device).triu(diagonal=1)
            A = A.masked_fill(future, float('-inf'))

        A = F.softmax(A, dim=-1)

        # Weighted sum of values, then merge heads: [B, H, N, head_dim] -> [B, N, embed_dim].
        # reshape copies when needed, so no explicit .contiguous() is required.
        Y = (A @ V).transpose(1, 2).reshape(B, N, self.embed_dim)
        return self.out_proj(Y)


In [14]:
# Smoke test: causal self-attention on a random batch must preserve the input shape.
torch.manual_seed(0)  # reproducible weights and inputs on every Restart & Run All

batch_size = 5
num_tokens = 11
embed_dim = 16
num_heads = 4

# Honour the device/dtype selected in the configuration cell instead of the CPU default.
model = MultiHeadedAttention(embed_dim, num_heads).to(device=device, dtype=dtype)
x = torch.randn((batch_size, num_tokens, embed_dim), device=device, dtype=dtype)

out = model(x, attn_mask=True)
assert out.shape == (batch_size, num_tokens, embed_dim)
out.shape

torch.Size([5, 11, 16])