In [2]:
from torch import nn
import torch
import torch.nn.functional as F
from math import sqrt

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected from ``input_dim`` to
    ``embed_dim``, split into ``num_heads`` heads of size
    ``embed_dim // num_heads``, attended per head, re-merged, and passed
    through a final ``embed_dim -> embed_dim`` linear layer.

    Args:
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim`` evenly.
        embed_dim: Total size of the projected q/k/v vectors and of the output.
        input_dim: Size of the last dimension of the raw q/k/v inputs.
        dropout: Probability for the ``nn.Dropout`` handed to ``attention``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, num_heads, embed_dim, input_dim, dropout=0.1):
        super().__init__()

        # Without this check dim_heads * num_heads != embed_dim, and the
        # reshape in forward() either fails late or silently reinterprets the
        # sequence axis, producing wrong results.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_heads = embed_dim // num_heads     # dim_heads aka d_k

        self.q_lin = nn.Linear(input_dim, embed_dim)
        self.k_lin = nn.Linear(input_dim, embed_dim)
        self.v_lin = nn.Linear(input_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size):
        """Reshape a projected tensor to (batch, num_heads, seq, dim_heads)."""
        return projected.reshape(
            batch_size, -1, self.num_heads, self.dim_heads
        ).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q, k, v: Tensors whose first dimension is the batch and whose last
                dimension is ``input_dim``.
            mask: Optional mask forwarded unchanged to ``attention``; it is
                applied to per-head scores of shape
                (batch, num_heads, query_len, key_len), so it must broadcast
                to that shape.

        Returns:
            Tensor of shape (batch, query_len, embed_dim).
        """
        batch_size = q.size(0)

        q = self._split_heads(self.q_lin(q), batch_size)
        k = self._split_heads(self.k_lin(k), batch_size)
        v = self._split_heads(self.v_lin(v), batch_size)

        # `attention` is the module-level helper of this notebook; it is looked
        # up at call time, so it only has to exist before forward() runs.
        context = attention(q, k, v, self.dim_heads, mask=mask, dropout=self.dropout)

        # Merge heads back: (batch, heads, seq, dim_heads) -> (batch, seq, embed_dim).
        context = context.transpose(1, 2).contiguous().reshape(batch_size, -1, self.embed_dim)

        return self.out_proj(context)


In [4]:
# this is qingyuan's function
def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention.

    Computes ``softmax(q @ k^T / sqrt(d_k)) @ v`` over the last two dimensions.

    Args:
        q: Query tensor; last two dimensions are (query_len, d_k).
        k: Key tensor; last two dimensions are (key_len, d_k).
        v: Value tensor; second-to-last dimension is key_len.
        d_k: Scaling dimension; raw scores are divided by ``sqrt(d_k)``.
        mask: Optional tensor broadcastable to the score shape
            (..., query_len, key_len). Positions where ``mask == 0`` are
            excluded from attention.
        dropout: Optional callable (e.g. an ``nn.Dropout``) applied to the
            attention weights after the softmax.

    Returns:
        The attention-weighted combination of ``v``.
    """
    scaled_dot = torch.matmul(q, k.transpose(-2, -1)) / sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value of the score dtype instead of a
        # hard-coded -1e9: -1e9 is not representable in float16 and makes
        # masked_fill raise an overflow error for half-precision inputs.
        fill_value = torch.finfo(scaled_dot.dtype).min
        scaled_dot = scaled_dot.masked_fill(mask == 0, fill_value)
    weights = F.softmax(scaled_dot, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v)

In [5]:
# Smoke test: push a single 3-feature vector through MultiHeadAttention.
SEED = 0
torch.manual_seed(SEED)   # weights are randomly initialised and dropout is active; seed so a re-run reproduces the output

embed_dim = 512
num_heads = 8

q = torch.tensor([[0, 10, 0]], dtype=torch.float32)
k = torch.tensor([[0, 10, 0]], dtype=torch.float32)
v = torch.tensor([[0, 10, 0]], dtype=torch.float32)
input_dim = 3

mh = MultiHeadAttention(num_heads, embed_dim, input_dim)
# Call the module instance rather than .forward() so nn.Module hooks run.
scores = mh(q, k, v)
# A (1, input_dim) input is treated as batch=1 with a single position.
assert scores.shape == (1, 1, embed_dim)
print(scores)

tensor([[[ 7.1008e-01,  4.1344e+00,  4.2444e+00, -2.2098e-01,  2.4778e+00,
           1.8659e+00, -3.6529e-01, -1.8282e-01, -2.3154e+00,  2.8329e+00,
           5.9769e+00, -2.4970e+00,  9.0717e-01, -1.3421e+00, -1.3616e+00,
          -1.5755e+00,  5.6827e+00,  9.1596e-01, -1.2586e+00,  5.6572e-01,
           8.2308e-01,  2.1544e+00,  3.3808e+00, -1.1160e+00,  1.8727e+00,
           1.1822e+00,  2.2700e+00,  9.9076e-01,  1.6973e+00,  6.4900e-01,
          -7.5532e-01,  1.9348e+00, -1.7630e+00, -3.0904e+00, -1.4175e+00,
           6.9837e-01,  2.9211e+00,  2.9955e+00, -1.1018e+00, -9.8237e-01,
          -1.2206e+00,  8.6352e-01, -1.9078e+00,  7.8156e-01, -1.8281e+00,
          -7.4415e-01, -2.3792e+00, -2.6369e+00,  1.2690e+00,  1.0197e+00,
           3.6274e-01,  1.6704e+00, -3.4851e+00, -1.5627e-01,  2.7131e+00,
           7.3606e-01,  2.3178e+00,  2.8196e+00, -1.5758e+00,  2.8286e+00,
           1.0097e+00, -2.9510e-01,  7.4239e-01,  2.6189e-01, -4.6646e+00,
          -3.9559e-01, -8