In [26]:
import torch
from torch import nn 

$Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$

Each of Q, K, V has shape (batch_size, seq_length, num_features)

Components: 
- Scaled dot product attention 
- Attention Head 
- Multi head attention
- Position Encoding 
- Feed forward 
- Residual 
- Transformer Encoder Layer
- Transformer Encoder 

In [2]:
from torch import Tensor 
import torch.nn.functional as f 

def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor): 
    temp = query.bmm(key.transpose(1, 2)) # performs a batch matrix multiplication 
    scale = key.size(-1) ** 0.5 
    softmax = f.softmax(temp / scale, dim=-1)
    return softmax.bmm(value)


In [24]:
def scaled_dot_product_attention(key: Tensor, query: Tensor, value: Tensor): 
    temp = query.bmm(key.transpose(1,2))
    scale = key.size(-1) ** 0.5 
    softmax = f.softmax(temp / scale)
    return softmax.bmm(value)

In [25]:
data = [[[1,2,3], [4,5,6], [7,8,9]], [[10,11,12], [13,14,15], [16,17,18]], [[19,20,21], [22,23,24], [25,26,27]]]


V = torch.tensor(data, dtype=torch.float)
scaled_dot_product_attention(V, V, V)

class AttentionHead(nn.Module): 
    def __init__(self, dim_in: int, dim_q: int, dim_k: int): 
        super().__init__()
        self.q = nn.Linear(dim_in, dim_q)
        self.k = nn.Linear(dim_in, dim_k)
        self.v = nn.Linear(dim_in, dim_k)
    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        return scaled_dot_product_attention(self.q(query), self.k(key), self.v(value))
    
class MultiHeadAttention(nn.Module): 
    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int): 
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, dim_q, dim_k) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(num_heads * dim_k, dim_in)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: 
        return self.linear(
            torch.cat([h(query, key, value) for h in self.heads], dim=-1)
        )


  softmax = f.softmax(temp / scale)


tensor([[[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]],

        [[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]],

        [[66., 69., 72.],
         [66., 69., 72.],
         [66., 69., 72.]]])

Residual Network: a layer with skip connections 

In [None]:
class AttentionHead(nn.Module): 
    def __init__(self, dim_in: int, dim_q: int, dim_k: int): 
        super.__init__()
        self.keys = nn.Linear(dim_in, dim_k)
        self.query = nn.Linear(dim_in, dim_q)
        self.values = nn.Linear(dim_in, dim_k) 
    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: 
        return scaled_dot_product_attention(self.query(query), self.keys(key), self.values(value))
    
class MultiHeadAttention(nn.Module): 
    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int): 
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, dim_q, dim_k) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(num_heads * dim_k, dim_in)  
    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: 
        return self.linear(
            torch.cat([h(query, key, value) for h in self.heads], dim=-1)
        )

def position_encoding(
    seq_len: int, dim_model: int, device: torch.device = torch.device("gpu")
) -> Tensor: 
    pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
    dim = torch.arange(dim_model, dtype=torch.flaot, device=device).reshape(1, 1, -1)
    phase = pos / (1e4 **  (dim / dim_model))

    return torch.where(dim.long() % 2 == 0, torch.sin(phase), torch.cos(phase))

def feed_forward(dim_input: int=512, dim_feedforward: int=2048) -> nn.Module: 
    return nn.Sequential(
        nn.Linear(dim_input, dim_feedforward),
        nn.ReLU(),
        nn.Linear(dim_feedforward, dim_input),
    )
    
class Residual(nn.Module): 
    def  __init__(self, sublayer: nn.Module, dimension: int, dropout: float = 0.1): 
        super().__init__()
        self.sublayer = sublayer 
        self.norm = nn.LayerNorm(dimension)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *tensors: Tensor) -> Tensor: 
        return self.norm(tensors[0] + self.dropout(self.sublayer(*tensors)))
    
class TransformerEncoderLayer(nn.Module): 
    def __init__(
        self, 
        dim_model: int = 512, 
        num_heads: int = 6, 
        dim_feedforward: int = 2048, 
        dropout: float = 0.1
    ): 
        super().__init__() 
        dim_q = dim_k = max(dim_model // num_heads, 1)
        self.attention = Residual(
            MultiHeadAttention(num_heads, dim_model, dim_q, dim_k), 
            dimension = dim_model, 
            dropout = dropout,
        )

        self.feed_forward = Residual(
            feed_forward(dim_model, dim_feedforward), 
            dimension=dim_model, 
            dropout=dropout,
        )

    def forward(self, src: Tensor) -> Tensor: 
        src = self.attention(src, src, src)
        return self.feed_forward(src)

class TransformerEncoder(nn.Module): 
    def __init__(
        self, 
        num_layers: int = 6, 
        dim_model: int = 512, 
        num_heads: int = 8, 
        dim_feedforward: int = 2048, 
        dropout: float = 0.1, 
    ): 
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    dim_model, num_heads, dim_feedforward, dropout
                ) for _ in range(num_layers)
            ]
        )
    
    def forward(self, src: Tensor) -> Tensor: 
        seq_len, dimension = src.size(1), src.size(2)
        src += position_encoding(seq_len, dimension)
        for layer in self.layers: 
            src = layer(src)
        return src