## Single Head Attention Flow

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [27]:
# Toy input: a T x d matrix with one row per node (T = 10 nodes, d = 16 features per node).
X = torch.randn(10, 16)

# One learnable projection per role; each maps hidden_size 16 -> 16.
W_Q = nn.Linear(16, 16)
W_K = nn.Linear(16, 16)
W_V = nn.Linear(16, 16)

# Project the same input into queries, keys and values (each [10, 16]).
Q, K, V = W_Q(X), W_K(X), W_V(X)

# Pairwise scores Q·K^T ([10, 10]), divided by sqrt(hidden_size) = sqrt(16).
dot = (Q @ K.T) / 16 ** 0.5

# Optional additive mask: 0 keeps a pair, -inf removes it (softmax turns -inf into weight 0).
# An adjacency matrix can be converted into this form. Here nodes 0 and 2 are hidden
# from each other in both directions.
mask = torch.zeros(10, 10)
mask[[0, 2], [2, 0]] = float("-inf")

# Normalise every query row over its keys -> attention weights [10, 10].
attention = F.softmax(dot + mask, dim=-1)

# The two masked pairs must have received exactly zero weight.
assert (attention[[0, 2], [2, 0]] == 0).all()

# Weighted sum of the values: Z is [10, 16].
Z = attention @ V

# Final output projection (hidden_size 16 -> 16): output is [10, 16].
W_O = nn.Linear(16, 16)
output = W_O(Z)

## Single Head Attention Class

In [28]:
class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product self-attention followed by an output projection.

    Queries, keys and values are all linear projections of the same input, so
    this is self-attention. Scores are scaled by ``hidden_size ** -0.5``.

    Args:
        hidden_size: Size of the input's last (feature) dimension. The query,
            key, value and output projections all map hidden_size -> hidden_size.
        attention_dropout_rate: Dropout probability applied to the attention
            weights after the softmax.
    """

    def __init__(self, hidden_size=16, attention_dropout_rate=0):
        super().__init__()
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)
        self.linear_o = nn.Linear(hidden_size, hidden_size)
        self.scale = hidden_size ** -0.5
        self.att_dropout = nn.Dropout(attention_dropout_rate)

    def forward(self, x, attn_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[T, hidden_size]`` or, with leading batch
                dimensions, ``[..., T, hidden_size]``.
            attn_mask: Optional additive mask that broadcasts against the
                ``[..., T, T]`` score matrix (rows are queries, columns are
                keys). Use 0 to keep a pair and ``-inf`` to block it.

        Returns:
            Tensor with the same shape as ``x``.
        """
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)

        # [..., T, T]: entry (i, j) scores query i against key j.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            scores = scores + attn_mask

        # Normalise over the key axis. This must be the LAST dimension: a
        # hard-coded dim=1 is only the key axis for unbatched 2-D input and
        # would normalise over queries for [B, T, T] scores.
        weights = torch.softmax(scores, dim=-1)
        weights = self.att_dropout(weights)

        context = torch.matmul(weights, v)
        return self.linear_o(context)

# Instantiate the module with the same sizes as the step-by-step flow above
# (hidden_size 16, no attention dropout); no mask is used in this demo.
single_head_attention = SingleHeadAttention(hidden_size=16, attention_dropout_rate=0)

# Random input: 10 rows with 16 features each.
x = torch.randn((10, 16))

# Run the module on x without an attention mask.
output = single_head_attention(x, attn_mask=None)

## Multi Head Attention Flow

In [29]:
# Input: 10 samples with 16 features each.
X = torch.randn(10, 16)

# Full-width projections (16 -> 16); the heads are carved out of the 16 output features below.
W_Q = nn.Linear(16, 16)
W_K = nn.Linear(16, 16)
W_V = nn.Linear(16, 16)

# Head layout: 4 heads, each of width 16 // 4 = 4.
num_heads = 4
att_size = 16 // num_heads

# Project, split the features into heads ([10, 4, 4]), then move the head axis to the
# front ([num_heads, num_samples, att_size]) so torch.matmul batches over heads.
Q = W_Q(X).view(10, num_heads, att_size).transpose(0, 1)  # [4, 10, 4]
K = W_K(X).view(10, num_heads, att_size).transpose(0, 1)  # [4, 10, 4]
V = W_V(X).view(10, num_heads, att_size).transpose(0, 1)  # [4, 10, 4]

# Per-head scores Q·K^T, scaled by 1 / sqrt(att_size): [4, 10, 10].
dot = (att_size ** -0.5) * torch.matmul(Q, K.transpose(-2, -1))

# Normalise each query row over its keys (the last axis).
attention = F.softmax(dot, dim=-1)

# Weighted sum of the values per head, then bring the sample axis back to the front.
Z = torch.matmul(attention, V).transpose(0, 1)  # [num_samples, num_heads, att_size]

# Concatenate the heads of every sample back into one 16-wide feature vector.
Z_concat = Z.reshape(10, -1)  # [10, 16]

# Final output projection (hidden_size 16 -> 16): output is [10, 16].
W_O = nn.Linear(16, 16)
output = W_O(Z_concat)

## Multi Head Attention Class

In [30]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention followed by an output projection.

    The input is projected to queries, keys and values of width ``hidden_size``.
    Each projection is split into ``num_heads`` heads of width
    ``hidden_size // num_heads``, attention is computed per head, and the heads
    are concatenated and passed through a final linear layer.

    Args:
        hidden_size: Size of the input's last (feature) dimension.
        num_heads: Number of attention heads; must divide ``hidden_size``.
        attention_dropout_rate: Dropout probability applied to the attention
            weights after the softmax.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size=16, num_heads=4, attention_dropout_rate=0):
        super().__init__()
        # Fail early with a clear message: with a non-divisible hidden_size the
        # floor division below would silently drop features and the head split
        # in forward() would fail with an opaque shape error.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.att_size = hidden_size // num_heads
        self.scale = self.att_size ** -0.5  # 1 / sqrt(per-head width)
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)
        self.linear_o = nn.Linear(hidden_size, hidden_size)
        self.att_dropout = nn.Dropout(attention_dropout_rate)

    def _split_heads(self, t):
        """Reshape ``[..., T, hidden]`` into ``[..., num_heads, T, att_size]``."""
        t = t.reshape(*t.shape[:-1], self.num_heads, self.att_size)
        return t.transpose(-3, -2)

    def forward(self, x, attn_mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[T, hidden_size]`` or, with leading batch
                dimensions, ``[..., T, hidden_size]``.
            attn_mask: Optional additive mask that broadcasts against the
                ``[..., num_heads, T, T]`` score tensor (e.g. ``[T, T]`` to
                share one mask across heads). Use 0 to keep a pair and
                ``-inf`` to block it.

        Returns:
            Tensor with the same shape as ``x``.
        """
        q = self._split_heads(self.linear_q(x))
        k = self._split_heads(self.linear_k(x))
        v = self._split_heads(self.linear_v(x))

        # [..., num_heads, T, T]: per-head score of every query against every key.
        dot = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            dot = dot + attn_mask

        # Normalise over the key axis, which is always the last dimension.
        attention = torch.softmax(dot, dim=-1)
        attention = self.att_dropout(attention)

        # [..., num_heads, T, att_size] -> [..., T, num_heads, att_size] -> [..., T, hidden]
        Z = torch.matmul(attention, v).transpose(-3, -2)
        Z = Z.reshape(*x.shape[:-1], self.num_heads * self.att_size)
        output = self.linear_o(Z)
        return output

# Instantiate the module with the same sizes as the step-by-step flow above
# (hidden_size 16 split over 4 heads, no attention dropout); no mask is used in this demo.
multi_head_attention = MultiHeadAttention(hidden_size=16, num_heads=4, attention_dropout_rate=0)

# Random input: 10 rows with 16 features each.
x = torch.randn((10, 16))

# Run the module on x without an attention mask.
output = multi_head_attention(x, attn_mask=None)