## Graph Attention Networks v2 (GATv2)
Like GAT, GATv2 operates on graph data. A graph consists of nodes and the edges connecting them. For example, in the Cora dataset the nodes are research papers and the edges are citations that connect the papers.

The GATv2 operator fixes the static attention problem of the standard GAT. Static attention is when the attention to the key nodes has the same rank (order) for any query node. GAT computes attention from query node $i$ to key node $j$ as,
$$\begin{aligned}
e_{ij} &= \text{LeakyReLU}\left(\mathbf{a}^\top [\mathbf{W}\vec{h_i}\Vert\mathbf{W}\vec{h_j}]\right) \\
&= \text{LeakyReLU}\left(\mathbf{a}_1^\top\mathbf{W}\vec{h_i}+\mathbf{a}_2^\top\mathbf{W}\vec{h_j}\right)
\end{aligned}$$
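
The second form of the GAT score makes the static behaviour visible: for a fixed query $i$, the term $\mathbf{a}_1^\top\mathbf{W}\vec{h_i}$ is a constant offset, and LeakyReLU is monotonic, so the ordering over keys $j$ is decided by $\mathbf{a}_2^\top\mathbf{W}\vec{h_j}$ alone. The following is a minimal sketch of that argument using random tensors. The sizes, the seed, and all variable names are illustrative assumptions and do not come from the model trained here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the notebook)
n_nodes, in_features, n_hidden = 5, 4, 8

h = torch.randn(n_nodes, in_features)   # node features h_i
W = torch.randn(n_hidden, in_features)  # shared weight matrix W
a = torch.randn(2 * n_hidden)           # attention vector a = [a_1 || a_2]
a_1, a_2 = a[:n_hidden], a[n_hidden:]

g = h @ W.T                             # W h_i for every node: [n_nodes, n_hidden]

# e[i, j] = LeakyReLU(a_1^T W h_i + a_2^T W h_j)
query_term = g @ a_1                    # depends only on the query i
key_term = g @ a_2                      # depends only on the key j
e_gat = F.leaky_relu(query_term[:, None] + key_term[None, :], negative_slope=0.2)

# Ranking of the keys for each query node (one row per query)
gat_ranking = e_gat.argsort(dim=1)

# Static attention: every query ranks the keys the same way,
# and that ranking is the ranking of the key term alone.
static = bool((gat_ranking == gat_ranking[0]).all())
matches_key_term = bool((gat_ranking[0] == key_term.argsort()).all())
print("same key ranking for every query:", static)
```

The check relies only on monotonicity, so it should hold for any choice of `W` and `a`, barring exact ties between scores.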


GATv2 allows dynamic attention by changing the attention mechanism so that $\mathbf{a}$ is applied after the nonlinearity rather than before it:

$$\begin{aligned}
e_{ij} &= \mathbf{a}^\top \text{LeakyReLU} \left(\mathbf{W} [\vec{h_i} \Vert \vec{h_j}] \right) \\
&= \mathbf{a}^\top \text{LeakyReLU} \left(\mathbf{W}_i \vec{h_i} + \mathbf{W}_j \vec{h_j} \right)
\end{aligned}$$

*   **$e_{ij}$**: The **unnormalized** attention coefficient between node $i$ (query) and node $j$ (neighbor).
*   **$\mathbf{a}$**: A learnable **weight vector** (a shared attention mechanism parameter).
*   **$\mathbf{W}$**: A learnable **weight matrix** that transforms input features.
*   **$\mathbf{W}_i, \mathbf{W}_j$**: The **sub-matrices** obtained by splitting $\mathbf{W}$ column-wise, so that $\mathbf{W}_i$ multiplies $h_i$ and $\mathbf{W}_j$ multiplies $h_j$ in the concatenated vector (as shown in the second line of the equation).
*   **$h_i, h_j$**: The input **feature vectors** for nodes $i$ and $j$, respectively.
*   **$\Vert$**: The **concatenation** operation combining the two feature vectors.
*   **$\text{LeakyReLU}$**: The **non-linear activation function** applied element-wise.
*   **$\top$**: The **transpose** operation.
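
The two lines of the GATv2 equation can be checked numerically. The sketch below builds the concatenated form $\mathbf{W}[h_i \Vert h_j]$ for every pair of nodes and compares it with the split form $\mathbf{W}_i h_i + \mathbf{W}_j h_j$. It uses random tensors, and the sizes and variable names are assumptions chosen only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the notebook)
n_nodes, in_features, n_hidden = 5, 4, 8

h = torch.randn(n_nodes, in_features)       # node features
W = torch.randn(n_hidden, 2 * in_features)  # acts on the concatenation [h_i || h_j]
a = torch.randn(n_hidden)                   # attention vector

# Column-wise split of W: W_i multiplies h_i, W_j multiplies h_j
W_i, W_j = W[:, :in_features], W[:, in_features:]

# Form 1: concatenate every (i, j) pair, then apply W
h_i = h[:, None, :].expand(n_nodes, n_nodes, in_features)
h_j = h[None, :, :].expand(n_nodes, n_nodes, in_features)
pairs = torch.cat([h_i, h_j], dim=-1)       # [n_nodes, n_nodes, 2 * in_features]
e_concat = F.leaky_relu(pairs @ W.T, negative_slope=0.2) @ a

# Form 2: transform each node once, then add the two halves per pair
g_i = h @ W_i.T                             # [n_nodes, n_hidden]
g_j = h @ W_j.T                             # [n_nodes, n_hidden]
e_split = F.leaky_relu(g_i[:, None, :] + g_j[None, :, :], negative_slope=0.2) @ a

same = torch.allclose(e_concat, e_split, atol=1e-4)
print("both forms agree:", same)
```

The split form needs only two matrix products over the nodes instead of one product per pair, which is presumably why the layer implementation uses two linear layers and then sums their outputs.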

## Dataset
Train the model on `dataset/drugdata`

In [1]:
import torch
from torch import nn

In [3]:
class GraphAttentionV2Layer(nn.Module):
    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True, dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2, share_weights: bool = False):

        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads
        self.share_weights = share_weights

        # calculate the number of dimensions per head
        if is_concat:
            assert out_features % n_heads == 0
            # if we are concatenating the multiple heads
            self.n_hidden = out_features // n_heads
        else:
            # if we are averaging the multiple heads
            self.n_hidden = out_features

        # Linear layer for initial source transformation; i.e. to transform the source node embeddings before self-attention 
        self.linear_1 = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

        # if `share_weights` is true, the same layer is used for the target nodes
        if share_weights:
            self.linear_r = self.linear_1
        else:
            self.linear_r = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

        # linear layer to compute attention score e_ij
        self.attn = nn.Linear(self.n_hidden, 1, bias=False)

        # the activation for the attention score e_ij
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)

        # softmax to compute attention a_ij
        self.softmax = nn.Softmax(dim=1)

        # dropout layer to be applied for attention
        self.dropout = nn.Dropout(dropout)

    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
        """
        h: input node embeddings of shape [n_nodes, in_features].
        adj_mat: adjacency matrix of shape [n_nodes, n_nodes, n_heads]. We use shape
        [n_nodes, n_nodes, 1] since the adjacency is the same for each head. The adjacency
        matrix represents the edges (or connections) among nodes: adj_mat[i][j] is True
        if there is an edge from node i to node j.
        """
        # number of nodes
        n_nodes = h.shape[0]

        # The initial transformations, for each head. We do two linear transformations and then split it up for each head. 
        g_l = self.linear_1(h).view(n_nodes, self.n_heads, self.n_hidden)
        g_r = self.linear_r(h).view(n_nodes, self.n_heads, self.n_hidden)

        ## Calculate attention score
        g_l_repeat = g_l.repeat(n_nodes, 1, 1)

        g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)

        # now add the two tensors to get the sum for every pair of nodes
        g_sum = g_l_repeat + g_r_repeat_interleave

        # reshape so that g_sum[i, j] is g_l[j] + g_r[i]
        g_sum = g_sum.view(n_nodes, n_nodes, self.n_heads, self.n_hidden)

        # calculate e_ij
        e = self.attn(self.activation(g_sum))

        # remove the last dimension of size 1
        e = e.squeeze(-1)

        # validate the adjacency matrix shape: [n_nodes, n_nodes, n_heads] or [n_nodes, n_nodes, 1]
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

        # mask e_ij based on the adjacency matrix: e_ij is set to -inf if there is no edge from i to j
        e = e.masked_fill(adj_mat == 0, float('-inf'))

        # normalized attention scores (or coefficients)
        a = self.softmax(e)

        # apply dropout regularization
        a = self.dropout(a)

        # calculate final output for each head
        attn_res = torch.einsum('ijh, jhf->ihf', a, g_r)

        # concatenate the heads
        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            # take the mean of the heads
            return attn_res.mean(dim=1)
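
The pairing and aggregation steps of the forward pass can be checked in isolation. The sketch below repeats the `repeat` / `repeat_interleave` construction on random tensors, compares it with a plain broadcast sum, and then runs the masking, softmax, and `einsum` steps on a small graph. The ring-shaped adjacency with self-loops, the random scores `e`, and the tensor sizes are hypothetical stand-ins chosen for illustration; they are not the drug dataset or the trained layer.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the notebook)
n_nodes, n_heads, n_hidden = 4, 2, 3

g_l = torch.randn(n_nodes, n_heads, n_hidden)
g_r = torch.randn(n_nodes, n_heads, n_hidden)

# Pairwise sums built the same way as in the layer
g_sum = g_l.repeat(n_nodes, 1, 1) + g_r.repeat_interleave(n_nodes, dim=0)
g_sum = g_sum.view(n_nodes, n_nodes, n_heads, n_hidden)

# Broadcasting equivalent: entry [i, j] is g_l[j] + g_r[i]
g_sum_broadcast = g_r[:, None] + g_l[None, :]
pair_sums_match = torch.allclose(g_sum, g_sum_broadcast)

# Hypothetical adjacency: a directed ring with self-loops, shape [n_nodes, n_nodes, 1]
adj = torch.eye(n_nodes, dtype=torch.bool)
idx = torch.arange(n_nodes)
adj[idx, (idx + 1) % n_nodes] = True
adj_mat = adj.unsqueeze(-1)

# Stand-in attention scores; in the layer these come from the attention vector
e = torch.randn(n_nodes, n_nodes, n_heads)

# Mask non-edges, normalise over the key dimension j, then aggregate
e = e.masked_fill(adj_mat == 0, float('-inf'))
a = torch.softmax(e, dim=1)
out = torch.einsum('ijh,jhf->ihf', a, g_r)  # [n_nodes, n_heads, n_hidden]

print("pair sums match:", pair_sums_match)
print("output shape:", tuple(out.shape))
```

The self-loops in this sketch ensure that every row has at least one unmasked entry. A node with no edges at all would have every score set to `-inf`, and the softmax over that row would produce NaN, so adding self-loops to the adjacency matrix is a common precaution.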