# In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import numpy as np


class GraphAttentionLayer(nn.Module):
    """Single-head additive-attention layer over a dense adjacency matrix.

    The layer works in three steps:

    1. It projects node features with a learned weight matrix.
    2. It aggregates the projected features with ``adjacency @ h``.
    3. It mixes the nodes with attention coefficients
       ``alpha_ij = softmax_j(LeakyReLU(a^T [h_i || h_j]))`` and adds a bias.

    NOTE(review): the attention scores are *not* masked by ``adjacency``, so
    every node attends to all ``N`` nodes, not only to its neighbours as in
    the original GAT formulation. Graph structure enters only through the
    ``adjacency @ h`` aggregation. Confirm this is intended.

    Args:
        input_dim: Size of each input node feature vector.
        output_dim: Size of each output node feature vector.
        dropout_rate: Dropout probability applied to the attention matrix.
        alpha: Negative slope of the LeakyReLU applied to attention scores.
    """

    def __init__(self, input_dim, output_dim, dropout_rate=0.5, alpha=0.2):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.alpha = alpha
        # Allocated uninitialised here; filled in by the nn.init calls below.
        self.weights = nn.Parameter(torch.empty(input_dim, output_dim))
        self.bias = nn.Parameter(torch.empty(output_dim))
        # Attention vector a = [a_src; a_dst], scoring the ordered pair (h_i, h_j).
        self.attention = nn.Parameter(torch.empty(2 * output_dim, 1))

        # nn.init functions already run under no_grad, so `.data` is unnecessary.
        nn.init.xavier_uniform_(self.weights)
        nn.init.xavier_uniform_(self.attention)
        nn.init.zeros_(self.bias)

    def forward(self, x, adjacency):
        """Apply the layer to a 2-D node-feature matrix.

        Args:
            x: Node features; assumed shape ``(N, input_dim)``.
            adjacency: Dense matrix left-multiplied onto the projected
                features; assumed shape ``(N, N)``.

        Returns:
            Tensor of shape ``(N, output_dim)``. It stays 2-D even when
            ``N == 1`` or ``output_dim == 1``.
        """
        h = torch.matmul(x, self.weights)
        h = torch.matmul(adjacency, h)

        # a^T [h_i || h_j] == h_i . a_src + h_j . a_dst. Splitting `a` this way
        # builds only an (N, N) score matrix instead of an (N, N, 2F) tensor
        # holding every concatenated pair.
        out_dim = h.size(1)
        src_scores = torch.matmul(h, self.attention[:out_dim])  # (N, 1)
        dst_scores = torch.matmul(h, self.attention[out_dim:])  # (N, 1)
        scores = F.leaky_relu(src_scores + dst_scores.t(), negative_slope=self.alpha)

        # Normalise over j so that row i is a distribution over source nodes.
        attention_weights = F.softmax(scores, dim=1)
        attention_weights = self.dropout(attention_weights)

        # out_i = sum_j alpha_ij * h_j. A plain 2-D matmul avoids the former
        # unconditional squeeze(), which dropped size-1 node/feature axes.
        return torch.matmul(attention_weights, h) + self.bias


class GraphAttentionNetwork(nn.Module):
    """Multi-head attention stage followed by a single attention output layer.

    ``num_heads`` independent ``GraphAttentionLayer`` instances each map the
    input features to ``hidden_dim`` features. Their results are
    concatenated along dim 1 and fed to a final ``GraphAttentionLayer`` that
    produces ``output_dim`` features.

    NOTE(review): no non-linearity is applied between the concatenated heads
    and the output layer. Confirm whether one was intended.

    Args:
        input_dim: Size of each input node feature vector.
        hidden_dim: Output size of every head.
        output_dim: Output size of the final layer.
        num_heads: Number of parallel heads.
        dropout_rate: Attention dropout probability passed to every layer.
        alpha: LeakyReLU negative slope passed to every layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, dropout_rate=0.5, alpha=0.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # One independently parameterised layer per head, built in order.
        self.attentions = nn.ModuleList(
            [
                GraphAttentionLayer(input_dim, hidden_dim, dropout_rate, alpha)
                for _ in range(num_heads)
            ]
        )

        # The output layer consumes the concatenation of all head outputs.
        concat_dim = hidden_dim * num_heads
        self.out_att = GraphAttentionLayer(concat_dim, output_dim, dropout_rate, alpha)

    def forward(self, x, adjacency):
        """Run every head on ``x``, concatenate along dim 1, then apply ``out_att``."""
        head_outputs = [head(x, adjacency) for head in self.attentions]
        hidden = torch.cat(head_outputs, dim=1)
        return self.out_att(hidden, adjacency)


# Sample graph: Zachary's karate club, as shipped with NetworkX.
graph = nx.karate_club_graph()
num_nodes = graph.number_of_nodes()

# Dense adjacency matrix in graph.nodes() order. `to_numpy_array` yields the
# same values as `adjacency_matrix(...).todense()` but needs only NumPy, with
# no SciPy sparse round-trip and no np.matrix result.
# NOTE(review): entries come from the "weight" edge attribute when present
# (newer NetworkX versions may attach integer weights to this graph). Also,
# no self-loops are added, so `adjacency @ x` excludes each node's own
# features. Confirm both are intended.
adjacency = torch.tensor(nx.to_numpy_array(graph, dtype=np.float32), dtype=torch.float32)

# One-hot node features: the identity matrix gives each node its own feature.
features = torch.eye(num_nodes, dtype=torch.float32)

# Model hyper-parameters.
input_dim = features.shape[1]
hidden_dim = 8
output_dim = 2
num_heads = 2

# Create GAT model
model = GraphAttentionNetwork(input_dim, hidden_dim, output_dim, num_heads)

# Inference: eval() disables the attention dropout so the output is
# deterministic for fixed parameters; no_grad() skips building the autograd
# graph. Call model.train() again before any training step.
model.eval()
with torch.no_grad():
    output = model(features, adjacency)
print("Output shape:", output.shape)


# Out: Output shape: torch.Size([34, 2])
