# 1. Base Model

In [5]:
import math
import torch
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values come from three independent linear projections
    of the input, each mapping hidden_dim -> hidden_dim.

    Args:
        hidden_dim: size of the last input dimension (also the size of Q, K, V
            and of the output).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # One projection per role; creation order is query, key, value.
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Let every position of ``X`` attend to every position of ``X``.

        Args:
            X: input of shape (batch_size, seq_len, hidden_dim).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        queries = self.query_proj(X)
        keys = self.key_proj(X)
        values = self.value_proj(X)

        # (batch, seq, hidden) @ (batch, hidden, seq) -> (batch, seq, seq)
        scores = queries @ keys.transpose(1, 2)
        # Scale by sqrt(hidden_dim) before normalising over the key axis.
        probs = (scores / math.sqrt(self.hidden_dim)).softmax(dim=-1)

        # (batch, seq, seq) @ (batch, seq, hidden) -> (batch, seq, hidden)
        return probs @ values

# Random input: batch_size=3, seq_len=2, hidden_dim=4.
X = torch.rand(3, 2, 4)
self_att_net = SelfAttentionV1(hidden_dim=4)
self_att_net(X)

        


tensor([[[ 0.0035, -0.8721, -0.4094, -0.0316],
         [ 0.0035, -0.8722, -0.4094, -0.0316]],

        [[-0.1021, -0.7288, -0.3341, -0.0718],
         [-0.1018, -0.7292, -0.3345, -0.0717]],

        [[ 0.0352, -0.8378, -0.3445, -0.1136],
         [ 0.0345, -0.8374, -0.3438, -0.1138]]], grad_fn=<UnsafeViewBackward0>)

## Efficiency Improved (single fused Q/K/V projection)

In [6]:
class SelfAttentionV2(nn.Module):
    """Single-head scaled dot-product self-attention with a fused Q/K/V projection.

    A single ``nn.Linear(hidden_dim, 3 * hidden_dim)`` produces queries, keys
    and values in one matmul. The result is then split into three equal chunks.

    Args:
        hidden_dim: size of the last input dimension (also the size of Q, K, V
            and of the output).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Q, K and V concatenated along the last dimension.
        self.proj = nn.Linear(hidden_dim, hidden_dim * 3)

    def forward(self, X):
        """Let every position of ``X`` attend to every position of ``X``.

        Args:
            X: input of shape (batch_size, seq_len, hidden_dim).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).
        """
        QKV = self.proj(X)  # (batch, seq, 3 * hidden)
        # An integer split size yields chunks of that size: exactly Q, K, V.
        Q, K, V = torch.split(QKV, self.hidden_dim, dim=-1)

        # (batch, seq, hidden) @ (batch, hidden, seq) -> (batch, seq, seq)
        attention_weights = torch.matmul(Q, K.transpose(1, 2))
        attention_weights = torch.softmax(
            attention_weights / math.sqrt(self.hidden_dim), dim=-1
        )

        # (batch, seq, seq) @ (batch, seq, hidden) -> (batch, seq, hidden)
        output = torch.matmul(attention_weights, V)
        return output

X = torch.rand(3, 2, 4)  # (batch_size=3, seq_len=2, hidden_dim=4)

# Demo the fused-projection version defined in this cell.
self_att_net = SelfAttentionV2(4)
self_att_net(X)

tensor([[[ 0.3425, -0.7286, -0.6025, -0.2077],
         [ 0.3430, -0.7251, -0.5989, -0.2084]],

        [[ 0.4270, -0.6466, -0.5041, -0.2572],
         [ 0.4263, -0.6486, -0.5052, -0.2594]],

        [[ 0.2814, -0.7100, -0.6775,  0.0388],
         [ 0.2832, -0.7092, -0.6776,  0.0408]]], grad_fn=<UnsafeViewBackward0>)

## Adding attention dropout, an attention mask, and an output projection

In [7]:
class SelfAttentionV3(nn.Module):
    """Single-head self-attention with optional masking, attention dropout
    and a final output projection.

    Args:
        hidden_dim: size of the last input dimension (also the size of Q, K, V
            and of the output).
        dropout_rate: probability given to the ``nn.Dropout`` applied to the
            attention weights (active only in training mode).
    """

    def __init__(self, hidden_dim, dropout_rate=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim

        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
        self.attention_dropout = nn.Dropout(dropout_rate)

    def forward(self, X, attention_mask=None):
        """Apply masked self-attention followed by the output projection.

        Args:
            X: input of shape (batch_size, seq_len, hidden_dim).
            attention_mask: optional tensor broadcastable to
                (batch_size, seq_len, seq_len). Positions where it equals 0
                receive -inf scores and therefore zero attention weight.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim).

        NOTE: a query row whose mask is entirely 0 has only -inf scores. The
        softmax then produces NaN for that row.
        """
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        # (batch, seq, hidden) @ (batch, hidden, seq) -> (batch, seq, seq)
        attention_weights = torch.matmul(Q, K.transpose(1, 2))
        attention_weights = attention_weights / math.sqrt(self.hidden_dim)

        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(
                attention_mask == 0, float('-inf')
            )

        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.attention_dropout(attention_weights)

        # (batch, seq, seq) @ (batch, seq, hidden) -> (batch, seq, hidden)
        attention_results = torch.matmul(attention_weights, V)

        output = self.output_proj(attention_results)
        return output



## Multi-Head Self-Attention

In [10]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The hidden dimension is split evenly across ``head_num`` heads. Each head
    attends independently, and the concatenated head outputs pass through a
    final output projection.

    Args:
        hidden_dim: size of the last input dimension. Must be divisible by
            ``head_num``.
        head_dim: accepted for interface compatibility but NOT used. The
            per-head size is always ``hidden_dim // head_num``.
        head_num: number of attention heads.
        dropout_rate: probability given to the ``nn.Dropout`` applied to the
            attention weights (active only in training mode).

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, hidden_dim, head_dim, head_num, dropout_rate=0.1):
        super().__init__()
        # Fail at construction time rather than inside .view() in forward().
        if hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        # Derived so that head_num * head_dim == hidden_dim exactly; the
        # head_dim argument is ignored.
        self.head_dim = hidden_dim // head_num

        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)

        self.attention_dropout = nn.Dropout(dropout_rate)

    def forward(self, X, attention_mask=None):
        """Apply multi-head self-attention followed by the output projection.

        Args:
            X: input of shape (batch, seq, hidden_dim).
            attention_mask: optional tensor broadcastable to
                (batch, head_num, seq, seq), e.g. a (seq, seq) causal mask.
                Positions where it equals 0 receive -inf scores.

        Returns:
            Tensor of shape (batch, seq, hidden_dim).
        """
        batch, seq, _ = X.size()

        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        # (b, s, h) => (b, head_num, s, head_dim)
        q_state = Q.view(batch, seq, self.head_num, self.head_dim).transpose(1, 2)
        k_state = K.view(batch, seq, self.head_num, self.head_dim).transpose(1, 2)
        v_state = V.view(batch, seq, self.head_num, self.head_dim).transpose(1, 2)

        # (b, head_num, s, head_dim) @ (b, head_num, head_dim, s) -> (b, head_num, s, s)
        attention_weight = torch.matmul(
            q_state, k_state.transpose(-1, -2)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0, float("-inf")
            )

        attention_weight = torch.softmax(attention_weight, dim=-1)
        attention_weight = self.attention_dropout(attention_weight)

        # Weighted sum of values is a matrix product (two operands), not an
        # element-wise product.
        output_mid = torch.matmul(attention_weight, v_state)  # (b, head_num, s, head_dim)

        output_mid = output_mid.transpose(1, 2).contiguous()  # (b, s, head_num, head_dim)
        output_mid = output_mid.view(batch, seq, -1)  # merge heads: (b, s, h)

        output = self.output_proj(output_mid)
        return output



