In [2]:
import numpy as np
import torch
import torch.nn as nn

In [None]:
class Constants:
    """Hyperparameters consumed by the attention layer."""

    D: int = 768   # dimension of features
    D_k: int = 64  # dimension of key and query vectors
    D_v: int = 64  # dimension of value vector
    H: int = 12    # number of attention heads

    # Standard deviation handed to nn.init.normal_ for the attention weights.
    # This must be an assignment: a bare `std: 0.1` only records an annotation
    # (with 0.1 in the type slot) and never creates the attribute.
    std: float = 0.1

In [None]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    For every head ``h`` the layer computes::

        Q_h = X W_q[h]      K_h = X W_k[h]      V_h = X W_v[h]
        A_h = softmax(Q_h K_h^T / sqrt(D_k))

    and returns ``sum_h (A_h V_h) W_o[h]``. Summing the per-head projections
    is mathematically the same as concatenating the heads and applying a
    single ``(H * D_v, D)`` output matrix.

    Args:
        const: Object exposing the integer attributes ``D``, ``D_k``, ``D_v``
            and ``H``. An optional float attribute ``std`` sets the standard
            deviation of the normal weight initialisation (0.1 when absent).
    """

    # The annotation is quoted (forward reference) so this cell does not need
    # `Constants` to be defined at class-creation time.
    def __init__(self, const: "Constants"):
        super().__init__()
        # torch.empty takes a *shape*. torch.Tensor((a, b, c)) would instead
        # build a 1-D tensor holding the three numbers a, b, c.
        self.W_q = nn.Parameter(torch.empty(const.H, const.D, const.D_k))
        self.W_k = nn.Parameter(torch.empty(const.H, const.D, const.D_k))
        self.W_v = nn.Parameter(torch.empty(const.H, const.D, const.D_v))

        self.W_o = nn.Parameter(torch.empty(const.H, const.D_v, const.D))

        # 1 / sqrt(D_k): keeps the logits' variance independent of the
        # key/query dimension.
        self._scale = const.D_k ** -0.5

        std = getattr(const, "std", 0.1)
        for weight in (self.W_q, self.W_k, self.W_v, self.W_o):
            nn.init.normal_(weight, std=std)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``X``.

        Args:
            X: Tensor of shape ``(..., N, D)``; any leading dimensions are
                treated as batch dimensions.

        Returns:
            Tensor of shape ``(..., N, D)``.
        """
        # Insert a head axis so that (..., 1, N, D) broadcasts against the
        # (H, D, *) weight stacks, for both batched and unbatched input.
        X_heads = X.unsqueeze(-3)

        Q = X_heads @ self.W_q  # (..., H, N, D_k)
        K = X_heads @ self.W_k  # (..., H, N, D_k)
        V = X_heads @ self.W_v  # (..., H, N, D_v)

        scores = (Q @ K.transpose(-2, -1)) * self._scale  # (..., H, N, N)
        # Normalise over the key axis so each query's weights sum to one.
        A = torch.softmax(scores, dim=-1)

        # Project every head back to D features and merge the heads.
        return ((A @ V) @ self.W_o).sum(dim=-3)
    


