In [2]:
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F

## Self-Attention

References:
- https://medium.com/@wangdk93/implement-self-attention-and-cross-attention-in-pytorch-1f1a366c9d4b
- https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention

In [8]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Computes ``softmax(Q K^T / sqrt(d_qk)) V`` where Q, K and V are learned
    linear projections of the same input ``x``.

    Args:
        input_dim: size of the last dimension of the input.
        d_qk: width of the query/key projections (the dimension the
            Q.K dot product sums over).
        d_v: width of the value projection, i.e. the output feature size.

    Shapes:
        input:  ``[..., L, input_dim]`` (typically ``[B, L, input_dim]``)
        output: ``[..., L, d_v]``
    """

    def __init__(self, input_dim, d_qk, d_v):
        super().__init__()
        self.input_dim = input_dim
        self.d_qk = d_qk
        self.d_v = d_v

        # Learned projections of the same input into queries, keys and values.
        self.query = nn.Linear(input_dim, d_qk)  # -> [..., L, d_qk]
        self.key = nn.Linear(input_dim, d_qk)    # -> [..., L, d_qk]
        self.value = nn.Linear(input_dim, d_v)   # -> [..., L, d_v]

        # Normalize over the key axis (last dim) so each query's weights sum to 1.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Attend every position of ``x`` to every position of ``x``.

        Args:
            x: tensor of shape ``[..., L, input_dim]``.

        Returns:
            Tensor of shape ``[..., L, d_v]``: the attention-weighted sum of values.
        """
        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # Scale by sqrt(d_qk), the width the dot product sums over, so score
        # magnitude does not grow with the projection size (not input_dim).
        # torch.matmul (unlike torch.bmm) also accepts unbatched or >3-D inputs.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_qk ** 0.5)
        attention = self.softmax(scores)  # [..., L, L]
        return torch.matmul(attention, values)

Test: run one forward pass on random data and check the output shape.

In [13]:
# Toy input: B sub-graphs, each with N stations (the positions attended over);
# each station is described by an L-dim vector (its historical sequence, or an
# embedding of that history).
B, N, L = 32, 43, 8  # [batch size, stations per sub-graph, per-station feature length]
d_qk = 2 * L  # query/key projection width
d_v = 4       # value projection width = output features per station

torch.manual_seed(0)  # reproducible weights and input under Restart & Run All
x = torch.randn(B, N, L)

selfattention = SelfAttention(L, d_qk, d_v)
sa_out = selfattention(x)
assert sa_out.shape == (B, N, d_v)
sa_out.shape

torch.Size([32, 43, 4])

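## Multi-Head Self-Attention
Each head has its own query/key/value projections and attends independently; the head outputs are concatenated along the feature axis. Because the heads do not depend on each other, they can be computed in parallel — note, however, that this simple wrapper still evaluates them one after another in a Python loop.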

In [16]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head self-attention assembled from independent ``SelfAttention`` heads.

    Every head owns its own query/key/value projections. The per-head outputs
    (each ``[..., L, d_v]``) are concatenated along the last axis, giving
    ``[..., L, n_head * d_v]``.

    Note: heads are run sequentially in a Python loop, not fused into one kernel.
    """

    def __init__(self, n_head, input_dim, d_qk, d_v):
        super().__init__()
        head_list = []
        for _ in range(n_head):
            head_list.append(SelfAttention(input_dim, d_qk, d_v))
        self.heads = nn.ModuleList(head_list)

    def forward(self, x):
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

Test: the concatenated output should have `n_head * d_v` features per station.

In [17]:
# Three independent heads; their d_v-wide outputs are concatenated on the last axis.
n_head = 3
mha = MultiHeadAttentionWrapper(n_head, L, d_qk, d_v)

mha_out = mha(x)
assert mha_out.shape == (B, N, n_head * d_v)
mha_out.shape

torch.Size([32, 43, 12])