# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphAttentionLayer(nn.Module):
    """Single-head graph attention layer over a dense adjacency matrix.

    Node features are linearly projected, a scalar attention score is
    computed for every ordered node pair, the scores are masked by the
    adjacency matrix and softmax-normalised per row, and each node's output
    is the attention-weighted sum of the projected features.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        """Create the layer.

        Args:
            in_dim: Size of each input node feature vector.
            out_dim: Size of each output node feature vector.
        """
        super().__init__()

        # Transformation matrix shared by all nodes.
        self.W = nn.Linear(in_dim, out_dim, bias=False)

        # Attention vector applied to the concatenation [h_i || h_j].
        # torch.empty only allocates memory, so the parameter must be
        # initialised explicitly; otherwise it may hold arbitrary values
        # (including NaN/inf).
        self.a = nn.Parameter(torch.empty(out_dim * 2, 1))
        nn.init.xavier_uniform_(self.a, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Apply graph attention.

        Args:
            h: Node features, shape [N, in_dim].
            adj: Adjacency matrix, shape [N, N]. Entry (i, j) != 0 means
                node i attends to node j.

        Returns:
            New node features, shape [N, out_dim]. A node whose adjacency
            row is entirely zero receives an all-zero feature vector.
        """
        # Transform node features.
        h_prime = self.W(h)  # Shape [N, out_dim]
        out_dim = h_prime.size(1)

        # a^T [h_i || h_j] == a_src^T h_i + a_dst^T h_j, so the pairwise
        # scores come from two [N, 1] projections combined by broadcasting,
        # without materialising an [N, N, 2 * out_dim] tensor.
        src_score = torch.matmul(h_prime, self.a[:out_dim])  # Shape [N, 1]
        dst_score = torch.matmul(h_prime, self.a[out_dim:])  # Shape [N, 1]
        e = self.leakyrelu(src_score + dst_score.t())  # Shape [N, N]

        # Exclude non-neighbours from the softmax. Rows without any
        # neighbour are left finite here: a row of all -inf would make the
        # softmax (and its gradient) NaN.
        mask = adj != 0
        has_neighbor = mask.any(dim=1, keepdim=True)
        e = e.masked_fill(~mask & has_neighbor, float("-inf"))

        # Attention weights; the final mask zeroes the rows of isolated
        # nodes and is a no-op everywhere else.
        attention = F.softmax(e, dim=1).masked_fill(~mask, 0.0)

        # Aggregate neighbour features.
        h_new = torch.matmul(attention, h_prime)

        return h_new

# Example usage:
# Build a random adjacency matrix and random node features.
N = 5  # Number of nodes
in_dim = 3
adj = torch.randint(0, 2, (N, N))
# Add self-loops so every node has at least one neighbour; a random 0/1
# matrix can contain an all-zero row, which would leave that node's softmax
# with nothing to normalise over.
adj.fill_diagonal_(1)
node_features = torch.randn((N, in_dim))

# Initialize the GAT layer (3 -> 8 features) and apply it to the graph.
gat_layer = GraphAttentionLayer(in_dim, 8)
new_features = gat_layer(node_features, adj)
print(new_features)
