# Classic

In [4]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

In [14]:
class QKVMatrices(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

    def forward(self, x) -> Tensor:
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        return q, k, v

In [53]:
def attention(q, k, v) -> Tensor:
    d_model = q.shape[-1]
    k_t = k.transpose(-1, -2)
    scores = (q @ k_t) / math.sqrt(d_model)
    attention_weights = F.softmax(scores, dim=-1) @ v
    return attention_weights

In [54]:
x = torch.rand(2, 4, 2)

In [55]:
q, k, v = QKVMatrices(2)(x)

In [56]:
q

tensor([[[-0.8103,  0.1402],
         [-0.8005, -0.0514],
         [-0.8954,  0.1716],
         [-0.8005,  0.0955]],

        [[-0.8227,  0.2479],
         [-0.7278,  0.0112],
         [-0.7923,  0.1930],
         [-0.8169,  0.0800]]], grad_fn=<ViewBackward0>)

In [57]:
k

tensor([[[ 0.1049, -0.1987],
         [ 0.1564, -0.4886],
         [ 0.0665, -0.1348],
         [ 0.1196, -0.2679]],

        [[ 0.0735, -0.0345],
         [ 0.1669, -0.4089],
         [ 0.0982, -0.1229],
         [ 0.1176, -0.2879]]], grad_fn=<ViewBackward0>)

In [58]:
v

tensor([[[0.6289, 1.0547],
         [0.6324, 0.9571],
         [0.6069, 1.2646],
         [0.6316, 1.0144]],

        [[0.6251, 1.1256],
         [0.6507, 0.8126],
         [0.6332, 1.0336],
         [0.6275, 1.0464]]], grad_fn=<ViewBackward0>)

In [59]:
attention(q, k, v)

tensor([[[0.6247, 1.0760],
         [0.6248, 1.0743],
         [0.6247, 1.0765],
         [0.6247, 1.0755]],

        [[0.6338, 1.0093],
         [0.6340, 1.0066],
         [0.6338, 1.0087],
         [0.6339, 1.0076]]], grad_fn=<UnsafeViewBackward0>)

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module): 
    
    def __init__(self, d_model=2):
        super().__init__()
        
        self.d_model=d_model
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        
    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        sims = torch.matmul(q, k.transpose(-1, -2))

        scaled_sims = sims / torch.tensor(k.size(-1)**0.5)

        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
        attention_percents = F.softmax(scaled_sims, dim=-2)
        attention_scores = torch.matmul(attention_percents, v)
        
        return attention_scores

In [43]:
x = torch.rand(1, 4, 2)

In [45]:
Attention()(x, x, x)

tensor([[[ 0.1266, -0.4618],
         [ 0.1191, -0.4335],
         [ 0.1269, -0.4630],
         [ 0.1529, -0.5620]]], grad_fn=<UnsafeViewBackward0>)

In [67]:
class MultiHeadAttention(nn.Module): 
    
    def __init__(self, d_model: int =2, n_heads: int = 2):
        super().__init__()
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_heads = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        self.W_z = nn.Linear(d_model, d_model, bias=False)

    def pre_attention_reshape(self, x):
        # [B, L, E] -> [B, H, L, HD]
        B, L, E = x.shape  # B: batch size, L: SEQ_LEN, E: D_MODEL
        x = x.contiguous().view(B, L, self.n_heads, self.d_heads)
        x = x.transpose(1, 2)
        return x

    def post_attention_reshape(self, x):
        # [B, H, L, HD] -> [B, L, E]
        B, H, L, HD = x.shape  # B: batch size, H: N_HEADS, L: SEQ_LEN, HD: D_HEADS
        x = x.transpose(2, 1)
        x = x.contiguous().view((B, L, self.d_model))
        return x
        
    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        q = self.pre_attention_reshape(self.W_q(encodings_for_q))
        k = self.pre_attention_reshape(self.W_k(encodings_for_k))
        v = self.pre_attention_reshape(self.W_v(encodings_for_v))
        sims = torch.matmul(q, k.transpose(-1, -2))

        scaled_sims = sims / torch.tensor(k.size(-1)**0.5)

        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
        attention_percents = F.softmax(scaled_sims, dim=-2)
        attention_scores = torch.matmul(attention_percents, v)
        
        attention_scores = self.post_attention_reshape(attention_scores)
        z = self.W_z(attention_scores)
        
        return z

In [70]:
x = torch.rand(2, 6, 4)
x

tensor([[[0.7326, 0.3260, 0.2582, 0.9637],
         [0.7328, 0.7659, 0.9104, 0.0196],
         [0.2067, 0.9213, 0.2886, 0.7394],
         [0.9466, 0.4678, 0.6391, 0.6604],
         [0.4453, 0.9252, 0.7580, 0.3311],
         [0.0724, 0.2072, 0.4537, 0.7393]],

        [[0.5149, 0.1084, 0.3849, 0.0651],
         [0.0261, 0.0598, 0.7980, 0.2159],
         [0.9039, 0.2390, 0.5307, 0.4736],
         [0.2737, 0.7136, 0.5887, 0.8175],
         [0.6334, 0.8573, 0.1050, 0.0070],
         [0.7399, 0.5876, 0.6764, 0.0661]]])

In [71]:
MultiHeadAttention(d_model=4, n_heads=2)(x, x, x)

tensor([[[-0.1354,  0.1081,  0.1995, -0.3149],
         [-0.1302,  0.1288,  0.2171, -0.3159],
         [-0.1260,  0.1251,  0.2106, -0.3118],
         [-0.1329,  0.1156,  0.2055, -0.3146],
         [-0.1253,  0.1319,  0.2172, -0.3136],
         [-0.1312,  0.1200,  0.2094, -0.3177]],

        [[-0.1202,  0.0118,  0.1015, -0.2252],
         [-0.1168,  0.0184,  0.1064, -0.2240],
         [-0.1188,  0.0105,  0.0991, -0.2229],
         [-0.1125,  0.0177,  0.1026, -0.2176],
         [-0.1164,  0.0141,  0.1009, -0.2195],
         [-0.1156,  0.0170,  0.1038, -0.2205]]], grad_fn=<UnsafeViewBackward0>)