# Scaled Dot Product Attention

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from math import sqrt

In [4]:
def attention(q, k, v, d_k, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor; multiplied with ``k`` transposed over its last two
            dimensions.
        k: Key tensor.
        v: Value tensor; multiplied by the attention weights.
        d_k: Scaling dimension; the raw scores are divided by ``sqrt(d_k)``.
        mask: Optional tensor broadcastable to the score tensor. Positions
            where it equals 0 are excluded from attention.
        dropout: Optional callable applied to the attention weights.

    Returns:
        ``softmax(q @ k^T / sqrt(d_k)) @ v``, with the mask and dropout
        applied when they are provided.
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(d_k)
    if mask is not None:
        # Fill with the most negative finite value of the score dtype rather
        # than a fixed -1e9: that constant does not fit in float16 and makes
        # masked_fill fail on half-precision inputs.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v)

In [5]:
# Toy example: two queries attending over two key/value pairs.
q = torch.tensor(
    [[0, 10, 0],
     [1, 2, 3]],
    dtype=torch.float32,
)
k = torch.tensor(
    [[10, 0, 0],
     [0, 10, 0]],
    dtype=torch.float32,
)
v = torch.tensor(
    [[1, 0, 1],
     [10, 0, 5]],
    dtype=torch.float32,
)
# A zero entry hides that key from that query: query 0 only sees key 0.
mask = torch.tensor([[1, 0], [1, 1]])
# Scale by the key feature dimension (last axis of k).
output = attention(q, k, v, d_k=k.size(-1), mask=mask)

In [6]:
# Show the attended values; row 0 equals v[0] because the mask hides key 1 from query 0.
print(output)

tensor([[1.0000, 0.0000, 1.0000],
        [9.9721, 0.0000, 4.9876]])
